Adding Active Label Cleaning code (#559)
* initial commit
* updating the build
* flake8
* update main page
* add the links
* try to fix the env
* update build, gitignore and remove duplicate license
* update gitignore again
* Adding to changelog
* conda activate
* update again
* wrong instruction
* add data quality
* rephrase
* first pass on Readme.md
* switch from our to the, and clarify the cxr datasets
* move content to a separate markdown file
* move additional content to config readme file
* finish updating dataquality readme
* Rename
* PR comment
* todos
* changed default dir for cifar10 dataset

Co-authored-by: Ozan Oktay <ozan.oktay@microsoft.com>
Parent: 521c004357
Commit: 94553a5c0b
@ -158,4 +158,11 @@ Tests/ML/test_outputs
# This file contains the run recovery ID of the most recent job
most_recent_run.txt
# The default folder that contains downloaded Tensorboard files
tensorboard_runs

# InnerEye-DataQuality
InnerEye-DataQuality/cifar-10-python.tar.gz
InnerEye-DataQuality/name_stats_scoring.png
InnerEye-DataQuality/cifar-10-batches-py
InnerEye-DataQuality/logs
InnerEye-DataQuality/data
@ -22,6 +22,7 @@ jobs that run in AzureML.
ensemble) using the parameter `model_id`.
- ([#554](https://github.com/microsoft/InnerEye-DeepLearning/pull/554)) Added a parameter `pretraining_dataset_id` to
  `NIH_COVID_BYOL` to specify the name of the SSL training dataset.
- ([#559](https://github.com/microsoft/InnerEye-DeepLearning/pull/559)) Added the accompanying code for the ["Active label cleaning: Improving dataset quality under resource constraints"](https://arxiv.org/abs/2109.00574) paper. The code can be found in the [InnerEye-DataQuality](InnerEye-DataQuality/README.md) subfolder. It provides tools for training noise-robust models, running label cleaning simulations, and loading the label cleaning benchmark datasets.

### Changed
- ([#531](https://github.com/microsoft/InnerEye-DeepLearning/pull/531)) Updated PL to 1.3.8, torchmetrics and pl-bolts and changed relevant metrics and SSL code API.
@ -0,0 +1,5 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
@ -0,0 +1,115 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import logging
import time
from dataclasses import dataclass
from typing import Any, Optional

import numpy as np
import scipy.sparse
from scipy.special import softmax
from sklearn.neighbors import kneighbors_graph


@dataclass(init=True, frozen=True)
class GraphParameters:
    """Class for setting graph connectivity and diffusion parameters."""
    n_neighbors: int
    diffusion_alpha: float
    cg_solver_max_iter: int
    diffusion_batch_size: Optional[int]
    distance_kernel: str  # {'euclidean' or 'cosine'}


def _get_affinity_matrix(embeddings: np.ndarray,
                         n_neighbors: int,
                         distance_kernel: str = 'cosine') -> scipy.sparse.csr.csr_matrix:
    """
    :param embeddings: Input sample embeddings (n_samples x n_embedding_dim)
    :param n_neighbors: Number of neighbors in the KNN Graph
    :param distance_kernel: Distance kernel to compute sample similarities {'euclidean' or 'cosine'}
    """

    # Build a k-NN graph using the embeddings
    if distance_kernel == 'euclidean':
        sigma = embeddings.shape[1]
        knn_dist_graph = kneighbors_graph(embeddings, n_neighbors, mode='distance', metric='euclidean', n_jobs=-1)
        knn_dist_graph.data = np.exp(-1.0 * np.asarray(knn_dist_graph.data) ** 2 / (2.0 * sigma ** 2))
    elif distance_kernel == 'cosine':
        knn_dist_graph = kneighbors_graph(embeddings, n_neighbors, mode='distance', metric='cosine', n_jobs=-1)
        knn_dist_graph.data = 1.0 - np.asarray(knn_dist_graph.data) / 2.0
    else:
        raise ValueError(f"Unknown sample distance kernel {distance_kernel}")

    return knn_dist_graph


def build_connectivity_graph(normalised: bool = True, **affinity_kwargs: Any) -> np.ndarray:
    """
    Builds the connectivity graph and returns either the symmetrically normalised adjacency matrix
    (if `normalised` is True) or the unnormalised graph Laplacian.
    :param normalised: If set to True, the graph adjacency is normalised with the norm of the degree matrix.
    :param affinity_kwargs: Arguments required to construct an affinity matrix
                            (weights representing similarity between points).
    """

    # Build a symmetric adjacency matrix
    A = _get_affinity_matrix(**affinity_kwargs)
    W = 0.5 * (A + A.T)
    if normalised:
        # Normalize the similarity graph
        W = W - scipy.sparse.diags(W.diagonal())
        D = W.sum(axis=1)
        D[D == 0] = 1
        D_sqrt_inv = np.array(1. / np.sqrt(D))
        D_sqrt_inv = scipy.sparse.diags(D_sqrt_inv.reshape(-1))
        L_norm = D_sqrt_inv * W * D_sqrt_inv
        return L_norm
    else:
        num_samples = W.shape[0]
        D = W.sum(axis=1)
        D = np.diag(np.asarray(D).reshape(num_samples, ))
        L = D - W
        return L


def label_diffusion(inv_laplacian: np.ndarray,
                    labels: np.ndarray,
                    query_batch_ids: np.ndarray,
                    class_priors: Optional[np.ndarray] = None,
                    diffusion_normalizing_factor: float = 0.01) -> np.ndarray:
    """
    :param inv_laplacian: Inverse Laplacian of the graph (n_samples x n_samples)
    :param labels: Label distribution for all samples (n_samples x n_classes), with rows summing to one
    :param query_batch_ids: The ids of the query samples onto which the labels are diffused
    :param class_priors: Prior distribution of each class [n_classes,]
    :param diffusion_normalizing_factor: Factor to normalize the diffused labels
    """
    diffusion_start = time.time()

    # Input number of nodes and classes
    n_samples = labels.shape[0]
    n_classes = labels.shape[1]

    # Find the labelled set of nodes in the graph
    all_idx = np.array(range(n_samples))
    labeled_idx = np.setdiff1d(all_idx, query_batch_ids.flatten())
    assert (np.all(np.isin(query_batch_ids, all_idx)))
    assert (np.allclose(np.sum(labels, axis=1), np.ones(shape=(labels.shape[0])), rtol=1e-3))

    # Initialize the y vector for each class (eq 5 from the paper, normalized with the class size)
    # and apply label propagation
    y = np.zeros((n_samples, n_classes))
    y[labeled_idx] = labels[labeled_idx] / np.sum(labels[labeled_idx], axis=0, keepdims=True)
    if class_priors is not None:
        y = y * class_priors
    Z = np.matmul(inv_laplacian[query_batch_ids, :], y)

    # Normalise the diffused logits
    output = softmax(Z / diffusion_normalizing_factor, axis=1)
    # output = Z / Z.sum(axis=1)
    logging.debug("Graph diffusion time: {0:.2f} secs".format(time.time() - diffusion_start))

    return output
@ -0,0 +1,60 @@
## Config arguments for model training:
All possible config arguments are defined in [model_config.py](InnerEyeDataQuality/configs/models/model_config.py). Here you will find a summary of the most important config arguments:
* If you want to train a model with co-teaching, you will need to set `train.use_co_teaching: True` in your config.
* If you want to finetune from a pretrained SSL checkpoint:
  * You will need to set `train.use_self_supervision: True` to tell the code to load a pretrained checkpoint.
  * You will need to update the `train.self_supervision.checkpoints: [PATH_TO_SSL]` field with the checkpoints to use for the initialization of your model. Note that if you want to train a co-teaching model in combination with SSL pretrained initialization, your list of checkpoints needs to be of length 2.
  * You can also choose whether to freeze the encoder during finetuning with the `train.self_supervision.freeze_encoder` field.
* If you want to train a model using ELR, you can set `train.use_elr: True`. A minimal sketch combining some of these options is shown below.
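The following is an illustrative sketch (not a complete training config) of how these fields fit together; the checkpoint paths are placeholders that you would replace with your own SSL checkpoints.

```yaml
train:
  use_co_teaching: True        # enable co-teaching
  use_self_supervision: True   # initialise the model from SSL checkpoints
  self_supervision:
    # Two checkpoints are needed when combining co-teaching with SSL initialisation
    checkpoints: ["<PATH_TO_SSL_CKPT_1>", "<PATH_TO_SSL_CKPT_2>"]
    freeze_encoder: False      # set to True to keep the encoder frozen during finetuning
```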
### CIFAR10H
We provide configurations to run experiments on CIFAR10H with 15% and 30% noise rates, respectively.
* Configs for the 15% noise rate experiments can be found in [configs/models/cifar10h_noise_15](InnerEyeDataQuality/configs/models/cifar10h_noise_15). In detail, this folder contains configs for:
  * vanilla resnet training: [InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet.yaml)
  * co-teaching resnet training: [InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet_co_teaching.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet_co_teaching.yaml)
  * SSL + linear head training: [InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet_self_supervision_v3.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet_self_supervision_v3.yaml)
* Configs for the 30% noise rate experiments can be found in [configs/models/cifar10h_noise_30](InnerEyeDataQuality/configs/models/cifar10h_noise_30). In detail, this folder contains configs for:
  * vanilla resnet training: [InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet.yaml)
  * co-teaching resnet training: [InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_co_teaching.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_co_teaching.yaml)
  * SSL + linear head training: [InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_self_supervision_v3.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_self_supervision_v3.yaml)
  * ELR training: [InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_elr.yaml](InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_elr.yaml)
* Examples of configs for the models used in the model selection benchmark experiment can be found in the [configs/models/benchmark_3_idn](InnerEyeDataQuality/configs/models/benchmark_3_idn) folder.

### Noisy Chest-Xray
*Note:* To run any model on this dataset, you will first need to make sure the dataset is ready on your machine (see the Benchmark datasets > Noisy Chest-Xray > Pre-requisite section).

* We provide configurations corresponding to the experiments on the NoisyChestXray dataset with 13% noise rate in the [configs/models/cxr](InnerEyeDataQuality/configs/models/cxr) folder:
  * Vanilla resnet training: [InnerEyeDataQuality/configs/models/cxr/resnet.yaml](InnerEyeDataQuality/configs/models/cxr/resnet.yaml)
  * Co-teaching resnet training: [InnerEyeDataQuality/configs/models/cxr/resnet_coteaching.yaml](InnerEyeDataQuality/configs/models/cxr/resnet_coteaching.yaml)
  * Co-teaching from a pretrained SSL checkpoint: [InnerEyeDataQuality/configs/models/cxr/resnet_ssl_coteaching.yaml](InnerEyeDataQuality/configs/models/cxr/resnet_ssl_coteaching.yaml)

<br/><br/>
## Config arguments for label cleaning simulation:

#### More details about the selector config
Here is an example of a selector config, with details about each field (a minimal YAML sketch follows the list):

* `selector:`
  * `type`: Which selector to use. There are several options available:
    * `PosteriorBasedSelectorJoint`: Uses the ranking function proposed in the manuscript, CE(posteriors, labels) - H(posteriors).
    * `PosteriorBasedSelector`: Uses CE(posteriors, labels) as the ranking function.
    * `GraphBasedSelector`: Graph-based selection of the next samples, based on the embeddings of each sample.
    * `BaldSelector`: Selection of the next sample with the BALD objective.
  * `model_name`: The name that will be used for the legend of the simulation plot.
  * `model_config_path`: Path to the config file used to train the selection model.
  * `use_active_relabelling`: Whether to turn on the active component of the active learning framework. If set to True, the model will be retrained regularly during the selection process.
  * `output_directory`: Optional field where you can specify the output directory in which to store the results.
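As a minimal sketch, with a hypothetical model config path and output directory as placeholders, a selector config could look as follows:

```yaml
selector:
  type: ['PosteriorBasedSelector']
  model_name: 'My-Model'                            # legend entry in the simulation plot
  model_config_path: '<PATH_TO_MODEL_CONFIG>.yaml'  # config used to train the selection model
  use_active_relabelling: False                     # set to True to retrain the model during selection
  output_directory: '<OUTPUT_DIR>'                  # optional
```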
#### Off-the-shelf simulation configs
* We provide the configs for various selectors for the CIFAR10H experiments in the [configs/selection/cifar10h_noise_15](InnerEyeDataQuality/configs/selection/cifar10h_noise_15) and [configs/selection/cifar10h_noise_30](InnerEyeDataQuality/configs/selection/cifar10h_noise_30) folders.
* Likewise, for the NoisyChestXray dataset you can find a set of selector configs in the [configs/selection/cxr](InnerEyeDataQuality/configs/selection/cxr) folder.

<br/><br/>

## Configs for self-supervised model training:

CIFAR10H: To pretrain embeddings with contrastive learning on CIFAR10H, you can use the [cifar10h_byol.yaml](InnerEyeDataQuality/deep_learning/self_supervised/configs/cifar10h_byol.yaml) or the [cifar10h_simclr.yaml](InnerEyeDataQuality/deep_learning/self_supervised/configs/cifar10h_simclr.yaml) config files.

Chest X-ray: Provided that you have downloaded the dataset as explained in the Benchmark Datasets > Other Chest Xray Datasets > NIH Datasets > Pre-requisites section, you can easily launch an unsupervised pretraining run on the full NIH dataset using the [nih_byol.yaml](InnerEyeDataQuality/deep_learning/self_supervised/configs/nih_byol.yaml) or the [nih_simclr.yaml](InnerEyeDataQuality/deep_learning/self_supervised/configs/nih_simclr.yaml) configs. Don't forget to update the `dataset_dir` field of your config to reflect your local path, for example as sketched below.
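Assuming the SSL configs follow the same `dataset.dataset_dir` layout as the model configs, a hypothetical local override could look like this (the path is a placeholder):

```yaml
dataset:
  dataset_dir: /data/nih-cxr   # placeholder: replace with the local path of your downloaded NIH dataset
```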
@ -0,0 +1,48 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

from typing import Dict, List, Optional, Union

import yacs.config


class ConfigNode(yacs.config.CfgNode):
    def __init__(self, init_dict: Optional[Dict] = None, key_list: Optional[List] = None, new_allowed: bool = False):
        super().__init__(init_dict, key_list, new_allowed)

    def __str__(self) -> str:
        def _indent(s_: str, num_spaces: int) -> Union[str, List[str]]:
            s = s_.split('\n')
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * ' ') + line for line in s]
            s = '\n'.join(s)  # type: ignore
            s = first + '\n' + s  # type: ignore
            return s

        r = ''
        s = []
        for k, v in self.items():
            separator = '\n' if isinstance(v, ConfigNode) else ' '
            if isinstance(v, str) and not v:
                v = '\'\''
            attr_str = f'{str(k)}:{separator}{str(v)}'
            attr_str = _indent(attr_str, 2)  # type: ignore
            s.append(attr_str)
        r += '\n'.join(s)
        return r

    def as_dict(self) -> Dict:
        def convert_to_dict(node: ConfigNode) -> Dict:
            if not isinstance(node, ConfigNode):
                return node
            else:
                dic = dict()
                for k, v in node.items():
                    dic[k] = convert_to_dict(v)
                return dic

        return convert_to_dict(self)
@ -0,0 +1,205 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

from .config_node import ConfigNode

config = ConfigNode()

config.pretty_name = ""
config.device = 'cuda'

# cuDNN
config.cudnn = ConfigNode()
config.cudnn.benchmark = True
config.cudnn.deterministic = False

config.dataset = ConfigNode()
config.dataset.name = 'CIFAR10'
config.dataset.dataset_dir = ''
config.dataset.image_size = 32
config.dataset.n_channels = 3
config.dataset.n_classes = 10
config.dataset.num_samples = None
config.dataset.noise_temperature = 1.0
config.dataset.noise_offset = 0.0
config.dataset.noise_rate = None
config.dataset.csv_to_ignore = None
config.dataset.cxr_consolidation_noise_rate = 0.10

config.model = ConfigNode()
# options: 'cifar', 'imagenet'
# Use 'cifar' for small input images
config.model.type = 'cifar'
config.model.name = 'resnet_preact'
config.model.init_mode = 'kaiming_fan_out'
config.model.use_dropout = False

config.model.resnet = ConfigNode()
config.model.resnet.depth = 110  # for cifar type model
config.model.resnet.n_blocks = [2, 2, 2, 2]  # for imagenet type model
config.model.resnet.block_type = 'basic'
config.model.resnet.initial_channels = 16
config.model.resnet.apply_l2_norm = False  # if set to True, last activations are l2-normalised.

config.model.densenet = ConfigNode()
config.model.densenet.depth = 100  # for cifar type model
config.model.densenet.n_blocks = [6, 12, 24, 16]  # for imagenet type model
config.model.densenet.block_type = 'bottleneck'
config.model.densenet.growth_rate = 12
config.model.densenet.drop_rate = 0.0
config.model.densenet.compression_rate = 0.5

config.train = ConfigNode()
config.train.root_dir = ''
config.train.checkpoint = ''
config.train.resume_epoch = 0
config.train.restore_scheduler = True
config.train.batch_size = 128
# optimizer (options: sgd, adam)
config.train.optimizer = 'sgd'
config.train.base_lr = 0.1
config.train.momentum = 0.9
config.train.nesterov = True
config.train.weight_decay = 1e-4
config.train.start_epoch = 0
config.train.seed = 0
config.train.pretrained = False
config.train.no_weight_decay_on_bn = False
config.train.use_balanced_sampler = False

config.train.output_dir = 'experiments/exp00'
config.train.log_period = 100
config.train.checkpoint_period = 10

config.train.use_elr = False

# co_teaching defaults
config.train.use_co_teaching = False
config.train.use_teacher_model = False
config.train.co_teaching_consistency_loss = False
config.train.co_teaching_forget_rate = 0.2
config.train.co_teaching_num_gradual = 10
config.train.co_teaching_use_graph = False
config.train.co_teaching_num_warmup = 25

# self-supervision defaults
config.train.use_self_supervision = False
config.train.self_supervision = ConfigNode()
config.train.self_supervision.checkpoints = ['', '']
config.train.self_supervision.freeze_encoder = True
config.train.tanh_regularisation = 0.0

# optimizer
config.optim = ConfigNode()
# Adam
config.optim.adam = ConfigNode()
config.optim.adam.betas = (0.9, 0.999)

# scheduler
config.scheduler = ConfigNode()
config.scheduler.epochs = 160

# warm up (options: none, linear, exponential)
config.scheduler.warmup = ConfigNode()
config.scheduler.warmup.type = 'none'
config.scheduler.warmup.epochs = 0
config.scheduler.warmup.start_factor = 1e-3
config.scheduler.warmup.exponent = 4

# main scheduler (options: constant, linear, multistep, cosine)
config.scheduler.type = 'multistep'
config.scheduler.milestones = [80, 120]
config.scheduler.lr_decay = 0.1
config.scheduler.lr_min_factor = 0.001

# tensorboard
config.tensorboard = ConfigNode()
config.tensorboard.save_events = True

# train data loader
config.train.dataloader = ConfigNode()
config.train.dataloader.num_workers = 2
config.train.dataloader.drop_last = True
config.train.dataloader.pin_memory = False
config.train.dataloader.non_blocking = False

# validation data loader
config.validation = ConfigNode()
config.validation.batch_size = 256
config.validation.dataloader = ConfigNode()
config.validation.dataloader.num_workers = 2
config.validation.dataloader.drop_last = False
config.validation.dataloader.pin_memory = False
config.validation.dataloader.non_blocking = False

config.augmentation = ConfigNode()
config.augmentation.use_random_crop = True
config.augmentation.use_random_horizontal_flip = True
config.augmentation.use_random_affine = False
config.augmentation.use_label_smoothing = False
config.augmentation.use_random_color = False
config.augmentation.add_gaussian_noise = False
config.augmentation.use_gamma_transform = False
config.augmentation.use_random_erasing = False
config.augmentation.use_elastic_transform = False

config.augmentation.elastic_transform = ConfigNode()
config.augmentation.elastic_transform.sigma = 4
config.augmentation.elastic_transform.alpha = 35
config.augmentation.elastic_transform.p_apply = 0.5

config.augmentation.random_crop = ConfigNode()
config.augmentation.random_crop.scale = (0.9, 1.0)

config.augmentation.gaussian_noise = ConfigNode()
config.augmentation.gaussian_noise.std = 0.01
config.augmentation.gaussian_noise.p_apply = 0.5

config.augmentation.random_horizontal_flip = ConfigNode()
config.augmentation.random_horizontal_flip.prob = 0.5

config.augmentation.random_affine = ConfigNode()
config.augmentation.random_affine.max_angle = 0
config.augmentation.random_affine.max_horizontal_shift = 0.0
config.augmentation.random_affine.max_vertical_shift = 0.0
config.augmentation.random_affine.max_shear = 5

config.augmentation.random_color = ConfigNode()
config.augmentation.random_color.brightness = 0.0
config.augmentation.random_color.contrast = 0.1
config.augmentation.random_color.saturation = 0.1

config.augmentation.gamma = ConfigNode()
config.augmentation.gamma.scale = (0.5, 1.5)

config.augmentation.label_smoothing = ConfigNode()
config.augmentation.label_smoothing.epsilon = 0.1

config.augmentation.random_erasing = ConfigNode()
config.augmentation.random_erasing.scale = (0.01, 0.1)
config.augmentation.random_erasing.ratio = (0.3, 3.3)

config.preprocess = ConfigNode()
config.preprocess.use_resize = False
config.preprocess.use_center_crop = False
config.preprocess.center_crop_size = None
config.preprocess.histogram_normalization = ConfigNode()
config.preprocess.histogram_normalization.disk_size = 30
config.preprocess.resize = 32

# test config
config.test = ConfigNode()
config.test.checkpoint = None
config.test.output_dir = None
config.test.batch_size = 256

# test data loader
config.test.dataloader = ConfigNode()
config.test.dataloader.num_workers = 2
config.test.dataloader.pin_memory = False


def get_default_model_config() -> ConfigNode:
    return config.clone()
@ -0,0 +1,75 @@
|
|||
device: cuda
|
||||
cudnn:
|
||||
benchmark: True
|
||||
deterministic: True
|
||||
dataset:
|
||||
name: CIFAR10IDN
|
||||
num_samples: 10000
|
||||
noise_rate: 0.4
|
||||
model:
|
||||
type: cifar
|
||||
name: resnet
|
||||
init_mode: kaiming_fan_out
|
||||
resnet:
|
||||
depth: 50
|
||||
initial_channels: 16
|
||||
block_type: basic
|
||||
apply_l2_norm: True
|
||||
train:
|
||||
resume_epoch: 0
|
||||
seed: 1
|
||||
batch_size: 256
|
||||
optimizer: sgd
|
||||
base_lr: 0.05
|
||||
momentum: 0.9
|
||||
nesterov: True
|
||||
weight_decay: 1e-4
|
||||
output_dir: experiments/benchmark_3_idn/co_teaching_v5_res50
|
||||
log_period: 100
|
||||
checkpoint_period: 100
|
||||
use_co_teaching: True
|
||||
co_teaching_forget_rate: 0.38
|
||||
co_teaching_num_gradual: 10
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
validation:
|
||||
batch_size: 512
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
test:
|
||||
batch_size: 512
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
pin_memory: False
|
||||
scheduler:
|
||||
epochs: 120
|
||||
type: multistep
|
||||
milestones: [70, 100]
|
||||
lr_decay: 0.1
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_label_smoothing: False
|
||||
use_random_affine: False
|
||||
use_random_color: True
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_affine:
|
||||
max_angle: 15
|
||||
max_horizontal_shift: 0.05
|
||||
max_vertical_shift: 0.05
|
||||
max_shear: 5
|
||||
random_color:
|
||||
brightness: 0.5
|
||||
contrast: 0.5
|
||||
saturation: 0.5
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
selector:
|
||||
type: ['PosteriorBasedSelector']
|
||||
model_name: 'Self-Supervision'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/benchmark_3_idn/b3_ssl_cleaning_v2.yaml'
|
||||
output_directory: 'b3_ssl_cleaning_v2'
|
||||
tensorboard:
|
||||
save_events: False
|
|
@ -0,0 +1,8 @@
|
|||
selector:
|
||||
type: ['PosteriorBasedSelector']
|
||||
model_name: 'Self-Supervision-Active'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/benchmark_3_idn/b3_ssl_cleaning_v2.yaml'
|
||||
output_directory: 'b3_ssl_cleaning_v2_active'
|
||||
use_active_relabelling: True
|
||||
tensorboard:
|
||||
save_events: False
|
|
@ -0,0 +1,73 @@
|
|||
device: cuda
|
||||
cudnn:
|
||||
benchmark: True
|
||||
deterministic: True
|
||||
dataset:
|
||||
name: CIFAR10IDN
|
||||
num_samples: 10000
|
||||
noise_rate: 0.4
|
||||
model:
|
||||
type: cifar
|
||||
name: resnet
|
||||
init_mode: kaiming_fan_out
|
||||
resnet:
|
||||
depth: 110
|
||||
initial_channels: 16
|
||||
block_type: basic
|
||||
apply_l2_norm: True
|
||||
train:
|
||||
resume_epoch: 0
|
||||
seed: 1
|
||||
batch_size: 256
|
||||
optimizer: sgd
|
||||
base_lr: 0.005
|
||||
momentum: 0.9
|
||||
nesterov: True
|
||||
weight_decay: 1e-3
|
||||
output_dir: experiments/benchmark_3_idn/ssl_cleaning_v2
|
||||
log_period: 100
|
||||
checkpoint_period: 100
|
||||
tanh_regularisation: 0.30
|
||||
use_self_supervision: True
|
||||
self_supervision:
|
||||
checkpoints: ["cifar10h/self_supervised/lightning_logs/BYOL_seed_1/checkpoints/last.ckpt"]
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
validation:
|
||||
batch_size: 512
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
test:
|
||||
batch_size: 128
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
pin_memory: False
|
||||
scheduler:
|
||||
epochs: 120
|
||||
type: multistep
|
||||
milestones: [90]
|
||||
lr_decay: 0.1
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_label_smoothing: True
|
||||
use_random_affine: False
|
||||
use_random_color: True
|
||||
random_crop:
|
||||
scale: (0.9, 1.0)
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_color:
|
||||
brightness: 0.5
|
||||
contrast: 0.5
|
||||
saturation: 0.5
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
device: cuda
|
||||
cudnn:
|
||||
benchmark: True
|
||||
deterministic: True
|
||||
dataset:
|
||||
name: CIFAR10IDN
|
||||
num_samples: 10000
|
||||
noise_rate: 0.4
|
||||
model:
|
||||
type: cifar
|
||||
name: resnet
|
||||
init_mode: kaiming_fan_out
|
||||
resnet:
|
||||
depth: 110
|
||||
initial_channels: 16
|
||||
block_type: basic
|
||||
apply_l2_norm: True
|
||||
train:
|
||||
resume_epoch: 0
|
||||
seed: 1
|
||||
batch_size: 256
|
||||
optimizer: sgd
|
||||
base_lr: 0.005
|
||||
momentum: 0.9
|
||||
nesterov: True
|
||||
weight_decay: 1e-4
|
||||
output_dir: experiments/benchmark_3_idn/ssl_v1
|
||||
log_period: 100
|
||||
checkpoint_period: 100
|
||||
tanh_regularisation: 0.00
|
||||
use_self_supervision: True
|
||||
self_supervision:
|
||||
checkpoints: ["cifar10h/self_supervised/lightning_logs/BYOL_seed_1/checkpoints/last.ckpt"]
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
validation:
|
||||
batch_size: 512
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
test:
|
||||
batch_size: 128
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
pin_memory: False
|
||||
scheduler:
|
||||
epochs: 120
|
||||
type: multistep
|
||||
milestones: [90]
|
||||
lr_decay: 0.1
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_label_smoothing: False
|
||||
use_random_affine: False
|
||||
use_random_color: True
|
||||
random_crop:
|
||||
scale: (0.9, 1.0)
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_color:
|
||||
brightness: 0.5
|
||||
contrast: 0.5
|
||||
saturation: 0.5
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
device: cuda
|
||||
cudnn:
|
||||
benchmark: True
|
||||
deterministic: True
|
||||
dataset:
|
||||
name: CIFAR10IDN
|
||||
num_samples: 10000
|
||||
noise_rate: 0.4
|
||||
model:
|
||||
type: cifar
|
||||
name: resnet
|
||||
init_mode: kaiming_fan_out
|
||||
resnet:
|
||||
depth: 50
|
||||
initial_channels: 16
|
||||
block_type: basic
|
||||
apply_l2_norm: True
|
||||
train:
|
||||
resume_epoch: 0
|
||||
seed: 1
|
||||
batch_size: 256
|
||||
optimizer: sgd
|
||||
base_lr: 0.05
|
||||
momentum: 0.9
|
||||
nesterov: True
|
||||
weight_decay: 1e-4
|
||||
output_dir: experiments/benchmark_3_idn/vanilla_model_res50
|
||||
log_period: 100
|
||||
checkpoint_period: 100
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
validation:
|
||||
batch_size: 512
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
test:
|
||||
batch_size: 512
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
pin_memory: False
|
||||
scheduler:
|
||||
epochs: 120
|
||||
type: multistep
|
||||
milestones: [70, 100]
|
||||
lr_decay: 0.1
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_label_smoothing: False
|
||||
use_random_affine: False
|
||||
use_random_color: True
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_affine:
|
||||
max_angle: 15
|
||||
max_horizontal_shift: 0.05
|
||||
max_vertical_shift: 0.05
|
||||
max_shear: 5
|
||||
random_color:
|
||||
brightness: 0.5
|
||||
contrast: 0.5
|
||||
saturation: 0.5
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
device: cuda
|
||||
cudnn:
|
||||
benchmark: True
|
||||
deterministic: True
|
||||
dataset:
|
||||
name: CIFAR10H
|
||||
noise_temperature: 2.0
|
||||
num_samples: 5000
|
||||
model:
|
||||
type: cifar
|
||||
name: resnet
|
||||
init_mode: kaiming_fan_out
|
||||
resnet:
|
||||
depth: 110
|
||||
initial_channels: 16
|
||||
block_type: basic
|
||||
apply_l2_norm: True
|
||||
train:
|
||||
resume_epoch: 0
|
||||
seed: 1
|
||||
batch_size: 256
|
||||
optimizer: sgd
|
||||
base_lr: 0.1
|
||||
momentum: 0.9
|
||||
nesterov: True
|
||||
weight_decay: 1e-4
|
||||
output_dir: cifar10h_noise_15/vanilla_model
|
||||
log_period: 100
|
||||
checkpoint_period: 100
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
validation:
|
||||
batch_size: 512
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
test:
|
||||
batch_size: 512
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
pin_memory: False
|
||||
scheduler:
|
||||
epochs: 160
|
||||
type: multistep
|
||||
milestones: [80, 120]
|
||||
lr_decay: 0.1
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_label_smoothing: False
|
||||
use_random_affine: True
|
||||
use_random_color: True
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_affine:
|
||||
max_angle: 15
|
||||
max_horizontal_shift: 0.05
|
||||
max_vertical_shift: 0.05
|
||||
max_shear: 5
|
||||
random_color:
|
||||
brightness: 0.5
|
||||
contrast: 0.5
|
||||
saturation: 0.5
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
device: cuda
|
||||
cudnn:
|
||||
benchmark: True
|
||||
deterministic: True
|
||||
dataset:
|
||||
name: CIFAR10H
|
||||
noise_temperature: 2.0
|
||||
num_samples: 5000
|
||||
model:
|
||||
type: cifar
|
||||
name: resnet
|
||||
init_mode: kaiming_fan_out
|
||||
resnet:
|
||||
depth: 110
|
||||
initial_channels: 16
|
||||
block_type: basic
|
||||
apply_l2_norm: True
|
||||
train:
|
||||
resume_epoch: 0
|
||||
seed: 1
|
||||
batch_size: 256
|
||||
optimizer: sgd
|
||||
base_lr: 0.1
|
||||
momentum: 0.9
|
||||
nesterov: True
|
||||
weight_decay: 1e-4
|
||||
output_dir: cifar10h_noise_15/co_teaching
|
||||
log_period: 100
|
||||
checkpoint_period: 100
|
||||
use_co_teaching: True
|
||||
co_teaching_forget_rate: 0.15
|
||||
co_teaching_num_gradual: 10
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
validation:
|
||||
batch_size: 2048
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
test:
|
||||
batch_size: 128
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
pin_memory: False
|
||||
scheduler:
|
||||
epochs: 160
|
||||
type: multistep
|
||||
milestones: [80, 120]
|
||||
lr_decay: 0.1
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_label_smoothing: False
|
||||
use_random_affine: True
|
||||
use_random_color: True
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_affine:
|
||||
max_angle: 15
|
||||
max_horizontal_shift: 0.05
|
||||
max_vertical_shift: 0.05
|
||||
max_shear: 5
|
||||
random_color:
|
||||
brightness: 0.5
|
||||
contrast: 0.5
|
||||
saturation: 0.5
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
device: cuda
|
||||
cudnn:
|
||||
benchmark: True
|
||||
deterministic: True
|
||||
dataset:
|
||||
name: CIFAR10H
|
||||
noise_temperature: 2.0
|
||||
num_samples: 5000
|
||||
train:
|
||||
resume_epoch: 0
|
||||
seed: 2
|
||||
batch_size: 256
|
||||
optimizer: sgd
|
||||
base_lr: 0.01
|
||||
momentum: 0.9
|
||||
nesterov: True
|
||||
weight_decay: 1e-3
|
||||
output_dir: experiments/cifar10h_noise_15/vanilla_self_supervision_v3
|
||||
log_period: 100
|
||||
checkpoint_period: 100
|
||||
tanh_regularisation: 0.30
|
||||
use_self_supervision: True
|
||||
self_supervision:
|
||||
checkpoints: ["cifar10h/self_supervised/lightning_logs/BYOL_seed_1/checkpoints/last.ckpt"]
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
validation:
|
||||
batch_size: 512
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
test:
|
||||
batch_size: 128
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
pin_memory: False
|
||||
scheduler:
|
||||
epochs: 120
|
||||
type: multistep
|
||||
milestones: [90]
|
||||
lr_decay: 0.1
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_label_smoothing: True
|
||||
use_random_affine: False
|
||||
use_random_color: True
|
||||
random_crop:
|
||||
scale: (0.9, 1.0)
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_color:
|
||||
brightness: 0.5
|
||||
contrast: 0.5
|
||||
saturation: 0.5
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
device: cuda
|
||||
cudnn:
|
||||
benchmark: True
|
||||
deterministic: True
|
||||
dataset:
|
||||
name: CIFAR10H
|
||||
noise_temperature: 10.0
|
||||
num_samples: 5000
|
||||
model:
|
||||
type: cifar
|
||||
name: resnet
|
||||
init_mode: kaiming_fan_out
|
||||
resnet:
|
||||
depth: 110
|
||||
initial_channels: 16
|
||||
block_type: basic
|
||||
apply_l2_norm: True
|
||||
train:
|
||||
resume_epoch: 0
|
||||
seed: 1
|
||||
batch_size: 256
|
||||
optimizer: sgd
|
||||
base_lr: 0.1
|
||||
momentum: 0.9
|
||||
nesterov: True
|
||||
weight_decay: 1e-4
|
||||
output_dir: experiments/cifar10h_noise_30/vanilla_model
|
||||
log_period: 100
|
||||
checkpoint_period: 100
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
validation:
|
||||
batch_size: 512
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
test:
|
||||
batch_size: 512
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
pin_memory: False
|
||||
scheduler:
|
||||
epochs: 160
|
||||
type: multistep
|
||||
milestones: [80, 120]
|
||||
lr_decay: 0.1
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_label_smoothing: False
|
||||
use_random_affine: True
|
||||
use_random_color: True
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_affine:
|
||||
max_angle: 15
|
||||
max_horizontal_shift: 0.05
|
||||
max_vertical_shift: 0.05
|
||||
max_shear: 5
|
||||
random_color:
|
||||
brightness: 0.5
|
||||
contrast: 0.5
|
||||
saturation: 0.5
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
device: cuda
|
||||
cudnn:
|
||||
benchmark: True
|
||||
deterministic: True
|
||||
dataset:
|
||||
name: CIFAR10H
|
||||
noise_temperature: 10.0
|
||||
num_samples: 5000
|
||||
model:
|
||||
type: cifar
|
||||
name: resnet
|
||||
init_mode: kaiming_fan_out
|
||||
resnet:
|
||||
depth: 110
|
||||
initial_channels: 16
|
||||
block_type: basic
|
||||
apply_l2_norm: True
|
||||
train:
|
||||
resume_epoch: 0
|
||||
seed: 1
|
||||
batch_size: 256
|
||||
optimizer: sgd
|
||||
base_lr: 0.1
|
||||
momentum: 0.9
|
||||
nesterov: True
|
||||
weight_decay: 1e-4
|
||||
output_dir: experiments/cifar10h_noise_30/co_teaching
|
||||
log_period: 100
|
||||
checkpoint_period: 100
|
||||
use_co_teaching: True
|
||||
co_teaching_forget_rate: 0.28
|
||||
co_teaching_num_gradual: 10
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
validation:
|
||||
batch_size: 2048
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
test:
|
||||
batch_size: 128
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
pin_memory: False
|
||||
scheduler:
|
||||
epochs: 160
|
||||
type: multistep
|
||||
milestones: [80, 120]
|
||||
lr_decay: 0.1
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_label_smoothing: False
|
||||
use_random_affine: True
|
||||
use_random_color: True
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_affine:
|
||||
max_angle: 15
|
||||
max_horizontal_shift: 0.05
|
||||
max_vertical_shift: 0.05
|
||||
max_shear: 5
|
||||
random_color:
|
||||
brightness: 0.5
|
||||
contrast: 0.5
|
||||
saturation: 0.5
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
device: cuda:2
|
||||
cudnn:
|
||||
benchmark: True
|
||||
deterministic: True
|
||||
dataset:
|
||||
name: CIFAR10H
|
||||
noise_temperature: 10.0
|
||||
num_samples: 5000
|
||||
model:
|
||||
type: cifar
|
||||
name: resnet
|
||||
init_mode: kaiming_fan_out
|
||||
use_dropout: False
|
||||
resnet:
|
||||
depth: 110
|
||||
initial_channels: 16
|
||||
block_type: basic
|
||||
apply_l2_norm: True
|
||||
train:
|
||||
use_elr: True
|
||||
resume_epoch: 0
|
||||
seed: 0
|
||||
batch_size: 256
|
||||
optimizer: sgd
|
||||
base_lr: 0.1
|
||||
momentum: 0.9
|
||||
nesterov: True
|
||||
weight_decay: 1e-4
|
||||
output_dir: experiments/cifar10h_noise_30/elr
|
||||
log_period: 100
|
||||
checkpoint_period: 100
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
validation:
|
||||
batch_size: 512
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
test:
|
||||
batch_size: 512
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
pin_memory: False
|
||||
scheduler:
|
||||
epochs: 160
|
||||
type: multistep
|
||||
milestones: [80, 120]
|
||||
lr_decay: 0.1
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_label_smoothing: False
|
||||
use_random_affine: True
|
||||
use_random_color: True
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_affine:
|
||||
max_angle: 15
|
||||
max_horizontal_shift: 0.05
|
||||
max_vertical_shift: 0.05
|
||||
max_shear: 5
|
||||
random_color:
|
||||
brightness: 0.5
|
||||
contrast: 0.5
|
||||
saturation: 0.5
|
|
@ -0,0 +1,64 @@
|
|||
device: cuda
|
||||
cudnn:
|
||||
benchmark: True
|
||||
deterministic: True
|
||||
dataset:
|
||||
name: CIFAR10H
|
||||
noise_temperature: 10.0
|
||||
num_samples: 5000
|
||||
train:
|
||||
resume_epoch: 0
|
||||
seed: 2
|
||||
batch_size: 256
|
||||
optimizer: sgd
|
||||
base_lr: 0.01
|
||||
momentum: 0.9
|
||||
nesterov: True
|
||||
weight_decay: 1e-3
|
||||
output_dir: experiments/cifar10h_noise_30/vanilla_self_supervision_v3
|
||||
log_period: 100
|
||||
checkpoint_period: 100
|
||||
tanh_regularisation: 0.30
|
||||
use_self_supervision: True
|
||||
self_supervision:
|
||||
checkpoints: ["logs/cifar10h/self_supervised/byol_seed_1/checkpoints/last.ckpt"]
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
validation:
|
||||
batch_size: 512
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
test:
|
||||
batch_size: 128
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
pin_memory: False
|
||||
scheduler:
|
||||
epochs: 120
|
||||
type: multistep
|
||||
milestones: [90]
|
||||
lr_decay: 0.1
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_label_smoothing: True
|
||||
use_random_affine: False
|
||||
use_random_color: True
|
||||
random_crop:
|
||||
scale: (0.9, 1.0)
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_color:
|
||||
brightness: 0.5
|
||||
contrast: 0.5
|
||||
saturation: 0.5
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
device: cuda
|
||||
cudnn:
|
||||
benchmark: True
|
||||
deterministic: True
|
||||
dataset:
|
||||
name: CIFAR10H
|
||||
noise_temperature: 10.0
|
||||
num_samples: 5000
|
||||
model:
|
||||
type: cifar
|
||||
name: resnet
|
||||
init_mode: kaiming_fan_out
|
||||
use_dropout: True
|
||||
resnet:
|
||||
depth: 110
|
||||
initial_channels: 16
|
||||
block_type: basic
|
||||
apply_l2_norm: True
|
||||
train:
|
||||
resume_epoch: 0
|
||||
seed: 1
|
||||
batch_size: 256
|
||||
optimizer: sgd
|
||||
base_lr: 0.1
|
||||
momentum: 0.9
|
||||
nesterov: True
|
||||
weight_decay: 1e-4
|
||||
output_dir: experiments/cifar10h_noise_30/vanilla_with_dropout
|
||||
log_period: 100
|
||||
checkpoint_period: 100
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
validation:
|
||||
batch_size: 512
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
test:
|
||||
batch_size: 512
|
||||
dataloader:
|
||||
num_workers: 2
|
||||
pin_memory: False
|
||||
scheduler:
|
||||
epochs: 160
|
||||
type: multistep
|
||||
milestones: [80, 120]
|
||||
lr_decay: 0.1
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_label_smoothing: False
|
||||
use_random_affine: True
|
||||
use_random_color: True
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_affine:
|
||||
max_angle: 15
|
||||
max_horizontal_shift: 0.05
|
||||
max_vertical_shift: 0.05
|
||||
max_shear: 5
|
||||
random_color:
|
||||
brightness: 0.5
|
||||
contrast: 0.5
|
||||
saturation: 0.5
|
|
@ -0,0 +1,62 @@
|
|||
device: cuda:0
|
||||
cudnn:
|
||||
benchmark: True
|
||||
deterministic: True
|
||||
dataset:
|
||||
name: NoisyChestXray
|
||||
dataset_dir: /datadrive/rsna_train
|
||||
n_channels: 3
|
||||
n_classes: 2
|
||||
cxr_consolidation_noise_rate: 0.10
|
||||
model:
|
||||
type: cxr
|
||||
name: resnet50
|
||||
train:
|
||||
use_balanced_sampler: True
|
||||
pretrained: False
|
||||
seed: 5
|
||||
batch_size: 32
|
||||
optimizer: adam
|
||||
base_lr: 1e-5
|
||||
output_dir: cxr/nih10/vanilla
|
||||
log_period: 5
|
||||
checkpoint_period: 20
|
||||
use_co_teaching: False
|
||||
use_teacher_model: False
|
||||
dataloader:
|
||||
num_workers: 6
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
validation:
|
||||
batch_size: 32
|
||||
dataloader:
|
||||
num_workers: 6
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
preprocess:
|
||||
center_crop_size: 224
|
||||
resize: 256
|
||||
scheduler:
|
||||
epochs: 150
|
||||
type: constant
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_random_affine: True
|
||||
use_random_color: True
|
||||
use_random_crop: True
|
||||
use_gamma_transform: False
|
||||
use_random_erasing: False
|
||||
add_gaussian_noise: False
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_affine:
|
||||
max_angle: 30
|
||||
max_horizontal_shift: 0.00
|
||||
max_vertical_shift: 0.00
|
||||
max_shear: 15
|
||||
random_color:
|
||||
brightness: 0.2
|
||||
contrast: 0.2
|
||||
saturation: 0.0
|
||||
random_crop:
|
||||
scale: (0.8, 1.0)
|
|
@ -0,0 +1,67 @@
|
|||
device: cuda:3
|
||||
cudnn:
|
||||
benchmark: True
|
||||
deterministic: True
|
||||
dataset:
|
||||
name: NoisyChestXray
|
||||
dataset_dir: /datadrive/rsna_train
|
||||
n_channels: 3
|
||||
n_classes: 2
|
||||
cxr_consolidation_noise_rate: 0.10
|
||||
model:
|
||||
type: cxr
|
||||
name: resnet50
|
||||
train:
|
||||
use_balanced_sampler: True
|
||||
pretrained: False
|
||||
seed: 5
|
||||
batch_size: 32
|
||||
optimizer: adam
|
||||
base_lr: 1e-5
|
||||
output_dir: cxr/nih10/co_teaching
|
||||
log_period: 5
|
||||
checkpoint_period: 20
|
||||
use_co_teaching: True
|
||||
use_teacher_model: False
|
||||
co_teaching_consistency_loss: False
|
||||
co_teaching_use_graph: False
|
||||
co_teaching_forget_rate: 0.15
|
||||
co_teaching_num_gradual: 0
|
||||
co_teaching_num_warmup: 20
|
||||
dataloader:
|
||||
num_workers: 6
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
validation:
|
||||
batch_size: 16
|
||||
dataloader:
|
||||
num_workers: 6
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
preprocess:
|
||||
center_crop_size: 224
|
||||
resize: 256
|
||||
scheduler:
|
||||
epochs: 150
|
||||
type: constant
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_random_affine: True
|
||||
use_random_color: True
|
||||
use_random_crop: True
|
||||
use_gamma_transform: False
|
||||
use_random_erasing: False
|
||||
add_gaussian_noise: False
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_affine:
|
||||
max_angle: 30
|
||||
max_horizontal_shift: 0.00
|
||||
max_vertical_shift: 0.00
|
||||
max_shear: 15
|
||||
random_color:
|
||||
brightness: 0.2
|
||||
contrast: 0.2
|
||||
saturation: 0.0
|
||||
random_crop:
|
||||
scale: (0.8, 1.0)
|
|
@ -0,0 +1,72 @@
|
|||
device: cuda:3
|
||||
cudnn:
|
||||
benchmark: True
|
||||
deterministic: True
|
||||
dataset:
|
||||
name: NoisyChestXray
|
||||
dataset_dir: /datadrive/rsna_train
|
||||
n_channels: 3
|
||||
n_classes: 2
|
||||
cxr_consolidation_noise_rate: 0.10
|
||||
model:
|
||||
type: cxr
|
||||
name: resnet50
|
||||
train:
|
||||
use_balanced_sampler: True
|
||||
pretrained: False
|
||||
seed: 5
|
||||
batch_size: 32
|
||||
optimizer: adam
|
||||
base_lr: 1e-6
|
||||
output_dir: cxr/nih10/co_ssl
|
||||
log_period: 5
|
||||
checkpoint_period: 20
|
||||
use_co_teaching: True
|
||||
co_teaching_consistency_loss: False
|
||||
co_teaching_use_graph: False
|
||||
co_teaching_forget_rate: 0.15
|
||||
co_teaching_num_gradual: 0
|
||||
co_teaching_num_warmup: 10
|
||||
use_teacher_model: False
|
||||
use_self_supervision: True
|
||||
self_supervision:
|
||||
freeze_encoder: False
|
||||
checkpoints: ["/datadrive/InnerEye-DataQuality/logs/nih/ssup/byol_seed_3/checkpoints/epoch=999.ckpt",
|
||||
"/datadrive/InnerEye-DataQuality/logs/nih/ssup/byol_seed_3/checkpoints/epoch=999.ckpt"]
|
||||
dataloader:
|
||||
num_workers: 6
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
validation:
|
||||
batch_size: 32
|
||||
dataloader:
|
||||
num_workers: 6
|
||||
drop_last: False
|
||||
pin_memory: False
|
||||
scheduler:
|
||||
epochs: 100
|
||||
type: constant
|
||||
preprocess:
|
||||
center_crop_size: 224
|
||||
resize: 256
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_random_affine: True
|
||||
use_random_color: True
|
||||
use_random_crop: True
|
||||
use_gamma_transform: False
|
||||
use_random_erasing: False
|
||||
add_gaussian_noise: False
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_affine:
|
||||
max_angle: 30
|
||||
max_horizontal_shift: 0.00
|
||||
max_vertical_shift: 0.00
|
||||
max_shear: 15
|
||||
random_color:
|
||||
brightness: 0.2
|
||||
contrast: 0.2
|
||||
saturation: 0.0
|
||||
random_crop:
|
||||
scale: (0.8, 1.0)
|
|
@ -0,0 +1,5 @@
|
|||
selector:
|
||||
type: ['PosteriorBasedSelectorJoint']
|
||||
model_name: 'Coteaching'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet_co_teaching.yaml'
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
selector:
|
||||
type: ['PosteriorBasedSelectorJoint']
|
||||
model_name: 'Vanilla'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet.yaml'
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
selector:
|
||||
type: ['PosteriorBasedSelectorJoint']
|
||||
model_name: 'Self-Supervision-Joint'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_15/resnet_self_supervision_v3.yaml'
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
selector:
|
||||
type: ['PosteriorBasedSelectorJoint']
|
||||
model_name: 'Coteaching'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_co_teaching.yaml'
|
||||
output_directory: 'outputs/cifar10h_30/coteaching'
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
selector:
|
||||
type: ['PosteriorBasedSelectorJoint']
|
||||
model_name: 'Coteaching-Active'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_co_teaching.yaml'
|
||||
use_active_relabelling: True
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
selector:
|
||||
type: ['PosteriorBasedSelectorJoint']
|
||||
model_name: 'Vanilla'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet.yaml'
|
||||
output_directory: 'outputs/cifar10h_30/vanilla'
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
selector:
|
||||
type: ['PosteriorBasedSelectorJoint']
|
||||
model_name: 'Vanilla-Active'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet.yaml'
|
||||
use_active_relabelling: True
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
selector:
|
||||
type: ['BaldSelector']
|
||||
model_name: 'BALD-MC-Active'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_with_dropout.yaml'
|
||||
use_active_relabelling: True
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
selector:
|
||||
type: ['PosteriorBasedSelectorJoint']
|
||||
model_name: 'Self-Supervision'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_self_supervision_v3.yaml'
|
||||
output_directory: 'outputs/cifar10h_30/ssup_v3'
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
selector:
|
||||
type: ['PosteriorBasedSelectorJoint']
|
||||
model_name: 'Self-Supervision-Active'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/cifar10h_noise_30/resnet_self_supervision_v3.yaml'
|
||||
use_active_relabelling: True
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
selector:
|
||||
type: ['PosteriorBasedSelector']
|
||||
model_name: 'Coteaching'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/cxr/resnet_coteaching.yaml'
|
|
@ -0,0 +1,5 @@
|
|||
selector:
|
||||
type: ['PosteriorBasedSelector']
|
||||
model_name: 'SSL+Coteaching'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/cxr/resnet_ssl_coteaching.yaml'
|
||||
use_active_relabelling: False
|
|
@ -0,0 +1,5 @@
|
|||
selector:
|
||||
type: ['PosteriorBasedSelector']
|
||||
model_name: 'SSL+Coteaching+Active'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/cxr/resnet_ssl_coteaching.yaml'
|
||||
use_active_relabelling: True
|
|
@ -0,0 +1,4 @@
|
|||
selector:
|
||||
type: ['PosteriorBasedSelector']
|
||||
model_name: 'Vanilla'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/cxr/resnet.yaml'
|
|
@ -0,0 +1,5 @@
|
|||
selector:
|
||||
type: ['PosteriorBasedSelector']
|
||||
model_name: 'Vanilla (heavy aug) Active'
|
||||
model_config_path: 'InnerEyeDataQuality/configs/models/cxr/resnet.yaml'
|
||||
use_active_relabelling: True
|
|
@ -0,0 +1,30 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

from InnerEyeDataQuality.configs.config_node import ConfigNode

config = ConfigNode()

# data selector
config.selector = ConfigNode()
config.selector.type = None
config.selector.model_name = None
config.selector.model_config_path = None
config.selector.use_active_relabelling = False

# Other selector parameters (unused)
config.selector.training_dynamics_data_path = None
config.selector.burnout_period = 0
config.selector.number_samples_to_relabel = 10

# output files
config.selector.output_directory = None

# tensorboard
config.tensorboard = ConfigNode()
config.tensorboard.save_events = False


def get_default_selector_config() -> ConfigNode:
    return config.clone()
@ -0,0 +1,140 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torchvision
|
||||
import PIL.Image
|
||||
|
||||
from default_paths import FIGURE_DIR
|
||||
from InnerEyeDataQuality.datasets.label_distribution import LabelDistribution
|
||||
from InnerEyeDataQuality.datasets.label_noise_model import get_cifar10h_confusion_matrix, \
|
||||
get_cifar10_asym_noise_model, get_cifar10_sym_noise_model
|
||||
from InnerEyeDataQuality.utils.plot import plot_confusion_matrix
|
||||
from InnerEyeDataQuality.datasets.cifar10_utils import get_cifar10_label_names
|
||||
|
||||
|
||||
class CIFAR10AsymNoise(torchvision.datasets.CIFAR10):
|
||||
"""
|
||||
Dataset class for the CIFAR10 dataset where target labels are sampled from a confusion matrix.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
root: str,
|
||||
train: bool,
|
||||
transform: Optional[Callable] = None,
|
||||
download: bool = True,
|
||||
use_fixed_labels: bool = True,
|
||||
seed: int = 1234,
|
||||
) -> None:
|
||||
"""
|
||||
:param root: The directory in which the CIFAR10 images will be stored.
|
||||
:param train: If True, creates dataset from training set, otherwise creates from test set.
|
||||
:param transform: Transform to apply to the images.
|
||||
:param download: Whether to download the dataset if it is not already in the local disk.
|
||||
:param use_fixed_labels: If true labels are sampled only once and are kept fixed. If false labels are sampled at
|
||||
each get_item() function call from label distribution.
|
||||
:param seed: The random seed that defines which samples are train/test and which labels are sampled.
|
||||
"""
|
||||
super().__init__(root, train=train, transform=transform, target_transform=None, download=download)
|
||||
|
||||
self.seed = seed
|
||||
self.targets = np.array(self.targets, dtype=np.int64) # type: ignore
|
||||
self.num_classes = np.unique(self.targets, return_counts=False).size
|
||||
self.num_samples = len(self.data)
|
||||
self.label_counts = np.eye(self.num_classes, dtype=np.int64)[self.targets]
|
||||
self.np_random_state = np.random.RandomState(seed)
|
||||
self.use_fixed_labels = use_fixed_labels
|
||||
self.clean_targets = np.argmax(self.label_counts, axis=1)
|
||||
logging.info(f"Preparing dataset: CIFAR10-Asym-Noise (N={self.num_samples})")
|
||||
|
||||
# Create label distribution for simulation of label adjudication
|
||||
self.label_distribution = LabelDistribution(seed, self.label_counts, temperature=1.0)
|
||||
|
||||
# Add label noise using the configured noise transition model
|
||||
self.noise_model = self.create_noise_transition_model(self.targets, self.num_classes, "cifar10_sym")
|
||||
|
||||
# Sample fixed labels from the distribution
|
||||
if use_fixed_labels:
|
||||
self.targets = self.sample_labels_from_model()
|
||||
# Identify label noise cases
|
||||
noise_rate = np.mean(self.clean_targets != self.targets) * 100.0
|
||||
|
||||
# Check the class distribution after sampling
|
||||
class_distribution = self.get_class_frequencies(targets=self.targets, num_classes=self.num_classes)
|
||||
|
||||
# Log dataset details
|
||||
logging.info(f"Class distribution (%) (true labels): {class_distribution * 100.0}")
|
||||
logging.info(f"Label noise rate: {noise_rate}")
|
||||
|
||||
@staticmethod
|
||||
def create_noise_transition_model(labels: np.ndarray, num_classes: int, noise_model: str) -> np.ndarray:
|
||||
logging.info(f"Using {noise_model} label noise model")
|
||||
if noise_model == "cifar10h":
|
||||
transition_matrix = get_cifar10h_confusion_matrix(temperature=2.0)
|
||||
elif noise_model == "cifar10_asym":
|
||||
transition_matrix = get_cifar10_asym_noise_model(eta=0.4)
|
||||
elif noise_model == "cifar10_sym":
|
||||
transition_matrix = get_cifar10_sym_noise_model(eta=0.4)
|
||||
else:
|
||||
raise ValueError("Unknown noise transition model")
|
||||
assert np.all(np.abs(np.sum(transition_matrix, axis=1) - 1.0) < 1e-6)  # Check that each row sums to one.
|
||||
|
||||
# Visualise the noise model
|
||||
plot_confusion_matrix(list(), list(), get_cifar10_label_names(), cm=transition_matrix, save_path=FIGURE_DIR)
|
||||
|
||||
# Compute the expected noise rate
|
||||
assert labels.ndim == 1
|
||||
exp_noise_rate = 1.0 - np.sum(
|
||||
np.diag(transition_matrix) * CIFAR10AsymNoise.get_class_frequencies(labels, num_classes))
|
||||
logging.info(f"Expected noise rate (transition model): {exp_noise_rate}")
|
||||
|
||||
return transition_matrix
|
||||
|
||||
def sample_labels_from_model(self, sample_index: Optional[int] = None) -> np.ndarray:
|
||||
# Sample based on the transition matrix and original labels (labels, transition mat, seed)
|
||||
if sample_index is not None:
|
||||
cur_label = self.targets[sample_index]
|
||||
label = self.np_random_state.choice(self.num_classes, 1, p=self.noise_model[cur_label, :])[0]
|
||||
return label
|
||||
|
||||
noisy_targets = np.zeros_like(self.targets)
|
||||
for ii in range(self.num_samples):
|
||||
cur_label = self.targets[ii]
|
||||
noisy_targets[ii] = self.np_random_state.choice(self.num_classes, 1, p=self.noise_model[cur_label, :])[0]
|
||||
|
||||
return noisy_targets
|
||||
|
||||
@staticmethod
|
||||
def get_class_frequencies(targets: np.ndarray, num_classes: int) -> np.ndarray:
|
||||
"""
|
||||
Returns normalised frequency of each semantic class
|
||||
"""
|
||||
assert targets.ndim == 1
|
||||
class_ids, class_counts = np.unique(targets, return_counts=True)
|
||||
class_distribution = np.zeros(num_classes, dtype=np.float64)
|
||||
for ii in range(num_classes):
|
||||
if np.any(class_ids == ii):
|
||||
class_distribution[ii] = class_counts[class_ids == ii]/targets.size
|
||||
|
||||
return class_distribution
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[PIL.Image.Image, int]:
|
||||
"""
|
||||
:param index: The index of the sample to be fetched
|
||||
:return: The image and label tensors
|
||||
"""
|
||||
img = PIL.Image.fromarray(self.data[index])
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
if self.use_fixed_labels:
|
||||
target = self.targets[index]
|
||||
else:
|
||||
target = self.sample_labels_from_model(sample_index=index)
|
||||
|
||||
return img, int(target)
|
|
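A minimal usage sketch for the class above (the import path is an assumption; the constructor also writes a confusion-matrix figure via FIGURE_DIR, so the repository's default_paths must be importable):

import torchvision.transforms as T
from InnerEyeDataQuality.datasets.cifar10_asym_noise import CIFAR10AsymNoise  # assumed module path

dataset = CIFAR10AsymNoise(root='data', train=True, transform=T.ToTensor(), use_fixed_labels=True)
img, label = dataset[0]  # image tensor and its (possibly noisy) sampled label
observed_noise = (dataset.targets != dataset.clean_targets).mean()
print(f"Observed noise rate: {observed_noise:.2%}")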
@ -0,0 +1,153 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
from typing import Callable, Optional, Tuple, List, Generator
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torchvision
|
||||
from InnerEyeDataQuality.datasets.label_distribution import LabelDistribution
|
||||
from InnerEyeDataQuality.datasets.tools import get_instance_noise_model
|
||||
from InnerEyeDataQuality.utils.generic import convert_labels_to_one_hot
|
||||
from pl_bolts.models.self_supervised.resnets import resnet50_bn
|
||||
from InnerEyeDataQuality.datasets.cifar10h import CIFAR10H
|
||||
|
||||
def chunks(lst: List, n: int) -> Generator:
|
||||
"""Yield successive n-sized chunks from lst."""
|
||||
for i in range(0, len(lst), n):
|
||||
yield lst[i:i + n]
|
||||
|
||||
class CIFAR10IDN(torchvision.datasets.CIFAR10):
|
||||
"""
|
||||
Dataset class for the CIFAR10 dataset where target labels are corrupted with instance-dependent label noise (IDN).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
root: str,
|
||||
train: bool,
|
||||
noise_rate: float,
|
||||
transform: Optional[Callable] = None,
|
||||
download: bool = True,
|
||||
use_fixed_labels: bool = True,
|
||||
seed: int = 1
|
||||
) -> None:
|
||||
"""
|
||||
:param root: The directory in which the CIFAR10 images will be stored.
|
||||
:param train: If True, creates dataset from training set, otherwise creates from test set.
|
||||
:param transform: Transform to apply to the images.
|
||||
:param download: Whether to download the dataset if it is not already in the local disk.
|
||||
:param noise_rate: Expected noise rate in the sampled labels.
|
||||
:param use_fixed_labels: If True, labels are sampled only once and kept fixed. If False, labels are sampled from
|
||||
the label distribution at each __getitem__() call.
|
||||
:param seed: The random seed that defines which samples are train/test and which labels are sampled.
|
||||
"""
|
||||
super().__init__(root, train=train, transform=transform, target_transform=None, download=download)
|
||||
self.seed = seed
|
||||
self.targets = np.array(self.targets, dtype=np.int64) # type: ignore
|
||||
self.num_classes = np.unique(self.targets, return_counts=False).size
|
||||
self.num_samples = len(self.data)
|
||||
self.clean_targets = np.copy(self.targets)
|
||||
self.np_random_state = np.random.RandomState(seed)
|
||||
self.use_fixed_labels = use_fixed_labels
|
||||
self.indices = np.array(range(self.num_samples))
|
||||
logging.info(f"Preparing dataset: CIFAR10-IDN (N={self.num_samples})")
|
||||
|
||||
# Set seed for torch operations
|
||||
initial_state = torch.get_rng_state()
|
||||
torch.manual_seed(self.seed)
|
||||
torch.cuda.manual_seed_all(self.seed)
|
||||
|
||||
# Collect image embeddings
|
||||
embeddings = self.get_cnn_image_embeddings(self.data)
|
||||
targets = torch.from_numpy(self.targets)
|
||||
|
||||
# Create label distribution for simulation of label adjudication
|
||||
label_counts = convert_labels_to_one_hot(self.clean_targets, n_classes=self.num_classes) if train else \
|
||||
CIFAR10H.download_cifar10h_labels(self.root)
|
||||
self.label_distribution = LabelDistribution(seed, label_counts, temperature=1.0)
|
||||
|
||||
# Build instance-dependent noise models for the labels
|
||||
self.noise_models = get_instance_noise_model(n=noise_rate,
|
||||
dataset=zip(embeddings, targets),
|
||||
labels=targets,
|
||||
num_classes=self.num_classes,
|
||||
feature_size=embeddings.shape[1],
|
||||
norm_std=0.01,
|
||||
seed=self.seed)
|
||||
|
||||
if self.use_fixed_labels:
|
||||
# Sample target labels
|
||||
self.targets = self.sample_labels_from_model()
|
||||
|
||||
# Check the class distribution after sampling
|
||||
class_distribution = self.get_class_frequencies(targets=self.targets, num_classes=self.num_classes)
|
||||
noise_rate = np.mean(self.clean_targets != self.targets) * 100.0
|
||||
|
||||
# Log dataset details
|
||||
logging.info(f"Class distribution (%) (true labels): {class_distribution * 100.0}")
|
||||
logging.info(f"Label noise rate: {noise_rate}")
|
||||
else:
|
||||
self.targets = None
|
||||
|
||||
# Restore initial state
|
||||
torch.set_rng_state(initial_state)
|
||||
|
||||
def sample_labels_from_model(self, sample_index: Optional[int] = None) -> np.ndarray:
|
||||
"""
|
||||
Samples class labels for each data point based on true label and instance dependent noise model
|
||||
"""
|
||||
classes = list(range(self.num_classes))
|
||||
if sample_index is not None:
|
||||
_t = self.np_random_state.choice(classes, p=self.noise_models[sample_index])
|
||||
else:
|
||||
_t = [self.np_random_state.choice(classes, p=self.noise_models[i]) for i in range(self.num_samples)]
|
||||
return np.array(_t)
|
||||
|
||||
@staticmethod
|
||||
def get_class_frequencies(targets: np.ndarray, num_classes: int) -> np.ndarray:
|
||||
"""
|
||||
Returns normalised frequency of each semantic class
|
||||
"""
|
||||
assert targets.ndim == 1
|
||||
class_ids, class_counts = np.unique(targets, return_counts=True)
|
||||
class_distribution = np.zeros(num_classes, dtype=np.float64)
|
||||
for ii in range(num_classes):
|
||||
if np.any(class_ids == ii):
|
||||
class_distribution[ii] = class_counts[class_ids == ii]/targets.size
|
||||
|
||||
return class_distribution
|
||||
|
||||
@staticmethod
|
||||
def get_cnn_image_embeddings(data: np.ndarray) -> torch.Tensor:
|
||||
"""
|
||||
Extracts image embeddings using a pre-trained model
|
||||
"""
|
||||
num_samples = data.shape[0]
|
||||
embeddings = list()
|
||||
data = torch.from_numpy(data).float().cuda()
|
||||
encoder = resnet50_bn(return_all_feature_maps=False, pretrained=True).cuda().eval()
|
||||
with torch.no_grad():
|
||||
for i in chunks(list(range(num_samples)), n=100):
|
||||
input = data[i].permute(0, 3, 1, 2)
|
||||
embeddings.append(encoder(input)[-1].cpu())
|
||||
return torch.cat(embeddings, dim=0).view(num_samples, -1)
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[PIL.Image.Image, int]:
|
||||
"""
|
||||
:param index: The index of the sample to be fetched
|
||||
:return: The image and label tensors
|
||||
"""
|
||||
img = PIL.Image.fromarray(self.data[index])
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
if self.use_fixed_labels:
|
||||
target = self.targets[index]
|
||||
else:
|
||||
target = self.sample_labels_from_model(sample_index=index)
|
||||
|
||||
return img, int(target)
|
|
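For the instance-dependent variant, note that image embeddings are computed with a pre-trained ResNet-50 on the GPU at construction time, so a CUDA device and pl_bolts are required. A short sketch (import path assumed):

from InnerEyeDataQuality.datasets.cifar10_idn import CIFAR10IDN  # assumed module path

dataset = CIFAR10IDN(root='data', train=True, noise_rate=0.3, use_fixed_labels=True)
print((dataset.targets != dataset.clean_targets).mean())  # should be close to the requested noise_rate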
@ -0,0 +1,165 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def split_dataset(labels: np.ndarray,
|
||||
reference_split: float,
|
||||
shuffle: bool = True,
|
||||
seed: int = 1234) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
:param labels: Complete label array that is split into two subsets, reference and noisy test set.
|
||||
:param reference_split: Ratio of total samples that will be put in the reference set.
|
||||
:param shuffle: If set to true, rows of the label matrix are shuffled prior to split.
|
||||
:param seed: Random seed used in sample shuffle
|
||||
"""
|
||||
# Create two testing sets from CIFAR10H Test Samples (10k - 10 classes)
|
||||
num_samples = labels.shape[0]
|
||||
num_samples_set1 = int(num_samples * reference_split)
|
||||
perm = np.random.RandomState(seed=seed).permutation(num_samples) if shuffle else np.array(range(num_samples))
|
||||
d_set1 = labels[perm[:num_samples_set1], :]
|
||||
d_set2 = labels[perm[num_samples_set1:], :]
|
||||
|
||||
return d_set1, d_set2, perm
|
||||
|
||||
|
||||
def get_cifar10h_labels(reference_split: float = 0.5,
|
||||
shuffle: bool = True,
|
||||
immutable: bool = True) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
:param reference_split: The sample ratio between gold standard test set and full test set
|
||||
:param shuffle: Shuffle samples prior to split.
|
||||
:param immutable: If set to True, the returned arrays are only read only.
|
||||
"""
|
||||
cifar10h_counts = download_cifar10h_data() # Num_samples x n_classes
|
||||
d_split1_standard, d_split2_complete, permutation = split_dataset(cifar10h_counts, reference_split, shuffle)
|
||||
d_split1_permutation = permutation[:d_split1_standard.shape[0]]
|
||||
|
||||
if immutable:
|
||||
d_split1_standard.setflags(write=False)
|
||||
d_split2_complete.setflags(write=False)
|
||||
|
||||
return d_split1_standard, d_split1_permutation
|
||||
|
||||
|
||||
def download_cifar10h_data() -> np.ndarray:
|
||||
"""
|
||||
Pulls cifar10h label data stream and returns it in numpy array.
|
||||
"""
|
||||
|
||||
url = 'https://raw.githubusercontent.com/jcpeterson/cifar-10h/master/data/cifar10h-counts.npy'
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
if response.status_code == requests.codes.ok:
|
||||
cifar10h_data = np.load(io.BytesIO(response.content))
|
||||
else:
|
||||
raise ValueError('CIFAR10H content was not found.')
|
||||
|
||||
return cifar10h_data
|
||||
|
||||
|
||||
def download_cifar10_data() -> Path:
|
||||
"""
|
||||
Downloads the CIFAR10 dataset and returns the path to the test batch file.
|
||||
"""
|
||||
import wget
|
||||
local_path = Path.cwd() / 'InnerEyeDataQuality' / 'downloaded_data'
|
||||
local_path_to_test_batch = local_path / 'cifar-10-batches-py/test_batch'
|
||||
|
||||
if not local_path_to_test_batch.exists():
|
||||
local_path.mkdir(parents=True, exist_ok=True)
|
||||
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
|
||||
path_to_tar = local_path / 'cifar10.tar.gz'
|
||||
wget.download(url, str(path_to_tar))
|
||||
tf = tarfile.open(str(path_to_tar))
|
||||
tf.extractall(local_path)
|
||||
os.remove(path_to_tar)
|
||||
|
||||
return local_path_to_test_batch
|
||||
|
||||
|
||||
def get_cifar10_label_names(file: Optional[Path] = None) -> List[str]:
|
||||
"""
|
||||
Returns the list of CIFAR-10 class names, either read from a downloaded batch file or using the default names.
|
||||
"""
|
||||
if file:
|
||||
dict = load_cifar10_file(file)
|
||||
label_names = [_s.decode("utf-8") for _s in dict[b"label_names"]]
|
||||
else:
|
||||
label_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
|
||||
|
||||
return label_names
|
||||
|
||||
|
||||
def load_cifar10_file(file: Path) -> Dict:
|
||||
"""
|
||||
Loads a pickled CIFAR-10 batch file and returns its contents as a dictionary.
|
||||
"""
|
||||
with open(file, 'rb') as fo:
|
||||
dict = pickle.load(fo, encoding='bytes')
|
||||
|
||||
return dict
|
||||
|
||||
|
||||
def plot_cifar10_images(sample_ids: List[int], save_directory: Path) -> None:
|
||||
"""
|
||||
Displays a set of CIFAR-10 Images based on the input sample ids. In the title of the figure,
|
||||
label distribution is displayed as well to understand the sample difficulty.
|
||||
"""
|
||||
|
||||
path_cifar10_test_batch = download_cifar10_data()
|
||||
test_batch = load_cifar10_file(path_cifar10_test_batch)
|
||||
plot_images(test_batch[b'data'], sample_ids, save_directory=save_directory)
|
||||
|
||||
|
||||
def plot_images(images: np.ndarray,
|
||||
selected_sample_ids: List[int],
|
||||
cifar10h_labels: Optional[np.ndarray] = None,
|
||||
label_names: Optional[List[str]] = None,
|
||||
save_directory: Optional[Path] = None) -> None:
|
||||
"""
|
||||
Displays a set of CIFAR-10 Images based on the input sample ids. In the title of the figure,
|
||||
label distribution is displayed as well to understand the sample difficulty.
|
||||
"""
|
||||
|
||||
f, ax = plt.subplots(figsize=(2, 2))
|
||||
for sample_id in selected_sample_ids:
|
||||
img = np.reshape(images[sample_id, :], (3, 32, 32)).transpose(1, 2, 0)
|
||||
ax.imshow(img)
|
||||
|
||||
if (cifar10h_labels is not None) and (label_names is not None):
|
||||
num_k_classes = 3
|
||||
label_distribution = cifar10h_labels[sample_id, :]
|
||||
k_min_val = np.sort(label_distribution)[-num_k_classes]
|
||||
available_classes = np.where(label_distribution >= k_min_val)[0]
|
||||
class_counts = label_distribution[available_classes]
|
||||
class_names = [label_names[_c] for _c in available_classes]
|
||||
ax_title = ''.join([a + '_' + str(b) + ' ' for a, b in zip(class_names, class_counts)])
|
||||
ax.set_title(ax_title)
|
||||
else:
|
||||
ax_title = f'CIFAR10H - Sample ID {sample_id}'
|
||||
ax.set_title(ax_title)
|
||||
|
||||
if save_directory:
|
||||
save_directory.mkdir(parents=True, exist_ok=True)
|
||||
f.savefig(save_directory / f"{sample_id}.png", bbox_inches='tight')
|
||||
ax.clear()
|
||||
plt.close(f)
|
||||
else:
|
||||
plt.show()
|
||||
f, ax = plt.subplots(figsize=(2, 2))
|
|
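As an illustration of how the helpers above fit together (a sketch; downloading requires network access and, for the CIFAR-10 archive, the optional wget dependency):

from pathlib import Path

counts, permutation = get_cifar10h_labels(reference_split=0.5, shuffle=True)
print(counts.shape)  # (5000, 10): per-class annotation counts for the reference split
print(get_cifar10_label_names())  # default class names
plot_cifar10_images(sample_ids=[0, 1, 2], save_directory=Path('figures'))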
@ -0,0 +1,184 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import PIL.Image
|
||||
import numpy as np
|
||||
import requests
|
||||
import torchvision
|
||||
|
||||
from InnerEyeDataQuality.datasets.cifar10_utils import get_cifar10_label_names
|
||||
from InnerEyeDataQuality.datasets.label_distribution import LabelDistribution
|
||||
from InnerEyeDataQuality.evaluation.metrics import compute_label_entropy
|
||||
from InnerEyeDataQuality.selection.simulation_statistics import SimulationStats, get_ambiguous_sample_ids
|
||||
from InnerEyeDataQuality.utils.generic import convert_labels_to_one_hot
|
||||
|
||||
|
||||
TOTAL_CIFAR10H_DATASET_SIZE = 10000
|
||||
|
||||
class CIFAR10H(torchvision.datasets.CIFAR10):
|
||||
"""
|
||||
Dataset class for the CIFAR10H dataset. The CIFAR10H dataset is the CIFAR10 test set but all the samples have
|
||||
been labelled by multiple annotators.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
root: str,
|
||||
transform: Optional[Callable] = None,
|
||||
num_samples: Optional[int] = None,
|
||||
preset_indices: Optional[np.ndarray] = None,
|
||||
seed: int = 1234,
|
||||
noise_temperature: float = 1.0,
|
||||
noise_offset: float = 0.0) -> None:
|
||||
"""
|
||||
:param root: The directory in which the CIFAR10 images will be stored
|
||||
:param transform: Transform to apply to the images
|
||||
:param num_samples: The number of samples to use out of a maximum of TOTAL_CIFAR10H_DATASET_SIZE
|
||||
:param seed: The random seed that defines which samples are train/test and which labels are sampled
|
||||
:param preset_indices: Image indices that will be used to create a subset of CIFAR10H dataset.
|
||||
If not specified and num_samples < 10000 then random sub-selection is performed.
|
||||
:param noise_temperature: A temperature value used to temperature-scale the label distribution.
|
||||
:param noise_offset: Offset parameter to control the noise rate in sampling initial labels.
|
||||
"""
|
||||
super().__init__(root, train=False, transform=transform, target_transform=None, download=True)
|
||||
num_samples = TOTAL_CIFAR10H_DATASET_SIZE if num_samples is None else num_samples
|
||||
|
||||
self.seed = seed
|
||||
cifar10h_labels = self.download_cifar10h_labels(self.root)
|
||||
self.num_classes = cifar10h_labels.shape[1]
|
||||
assert cifar10h_labels.shape[0] == TOTAL_CIFAR10H_DATASET_SIZE
|
||||
assert self.num_classes == 10
|
||||
assert 0 < num_samples <= TOTAL_CIFAR10H_DATASET_SIZE
|
||||
self.num_samples = num_samples
|
||||
|
||||
# Create the set of sample indices that make up this dataset (optionally keeping all ambiguous samples)
|
||||
self.indices = self.get_dataset_indices(num_samples, cifar10h_labels, keep_hard_samples=True, seed=seed) \
|
||||
if preset_indices is None else preset_indices
|
||||
self.verify_data_indices()
|
||||
self.label_counts = cifar10h_labels[self.indices]
|
||||
self.label_counts.flags.writeable = False
|
||||
self.true_label_entropy = compute_label_entropy(label_counts=self.label_counts)
|
||||
|
||||
self.label_distribution = LabelDistribution(seed, self.label_counts, noise_temperature, noise_offset)
|
||||
self.targets = self.label_distribution.sample_initial_labels_for_all()
|
||||
|
||||
# Check the class distribution
|
||||
_, class_counts = np.unique(self.targets, return_counts=True)
|
||||
class_distribution = np.array([_c/self.num_samples for _c in class_counts])
|
||||
logging.info(f"Preparing dataset: CIFAR10H (N={self.num_samples})")
|
||||
logging.info(f"Class distribution (%) (true labels): {class_distribution * 100.0}")
|
||||
self.clean_targets = np.argmax(self.label_counts, axis=1)
|
||||
# Identify true ambiguous and clear label noise cases
|
||||
self._identify_sample_types()
|
||||
|
||||
def _identify_sample_types(self) -> None:
|
||||
"""
|
||||
Stores and logs clear label noise and ambiguous case types.
|
||||
"""
|
||||
label_stats = SimulationStats(name="cifar10h", true_label_counts=self.label_counts,
|
||||
initial_labels=convert_labels_to_one_hot(self.targets, self.num_classes))
|
||||
self.ambiguous_mislabelled_cases = label_stats.mislabelled_ambiguous_sample_ids[0]
|
||||
self.clear_mislabeled_cases = label_stats.mislabelled_not_ambiguous_sample_ids[0]
|
||||
self.ambiguity_metric_args = {"ambiguous_mislabelled_ids": self.ambiguous_mislabelled_cases,
|
||||
"clear_mislabelled_ids": self.clear_mislabeled_cases,
|
||||
"true_label_entropy": self.true_label_entropy}
|
||||
|
||||
# Log dataset details
|
||||
logging.info(f"Ambiguous mislabeled cases: {100 * len(self.ambiguous_mislabelled_cases) / self.num_samples}%")
|
||||
logging.info(f"Clear mislabeled cases: {100 * len(self.clear_mislabeled_cases) / self.num_samples}%\n")
|
||||
|
||||
@classmethod
|
||||
def download_cifar10h_labels(cls, root: str = ".") -> np.ndarray:
|
||||
"""
|
||||
Pulls cifar10h label data stream and returns it in numpy array.
|
||||
"""
|
||||
try:
|
||||
cifar10h_labels = np.load(Path(root) / "cifar10h-counts.npy")
|
||||
except FileNotFoundError:
|
||||
url = 'https://raw.githubusercontent.com/jcpeterson/cifar-10h/master/data/cifar10h-counts.npy'
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
if response.status_code == requests.codes.ok:
|
||||
cifar10h_labels = np.load(io.BytesIO(response.content))
|
||||
else:
|
||||
raise ValueError('Failed to download CIFAR10H labels!')
|
||||
|
||||
return cifar10h_labels
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[PIL.Image.Image, int]:
|
||||
"""
|
||||
:param index: The index of the sample to be fetched
|
||||
:return: The image and label tensors
|
||||
"""
|
||||
img = PIL.Image.fromarray(self.data[self.indices[index]])
|
||||
target = self.targets[index]
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
return img, int(target)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
|
||||
:return: The size of the dataset
|
||||
"""
|
||||
return len(self.indices)
|
||||
|
||||
def get_label_names(self) -> List[str]:
|
||||
return get_cifar10_label_names()
|
||||
|
||||
def verify_data_indices(self) -> None:
|
||||
assert isinstance(self.indices, np.ndarray)
|
||||
assert self.indices.size == self.num_samples
|
||||
assert np.all(self.indices < TOTAL_CIFAR10H_DATASET_SIZE)
|
||||
assert np.all(0 <= self.indices)
|
||||
_, c = np.unique(self.indices, return_counts=True)
|
||||
assert np.all(c == 1)
|
||||
|
||||
def get_dataset_indices(self,
|
||||
num_samples: int,
|
||||
true_label_counts: np.ndarray,
|
||||
keep_hard_samples: bool,
|
||||
seed: int = 1234) -> np.ndarray:
|
||||
"""
|
||||
Function to choose a subset of the CIFAR10H dataset. Returns selected subset of sample
|
||||
indices in a shuffled order.
|
||||
|
||||
:param num_samples: Number of samples in the selected subset.
|
||||
If the full dataset size is specified then indices are just shuffled and returned.
|
||||
:param true_label_counts: True label counts of CIFAR10H images (num_samples x num_classes)
|
||||
:param keep_hard_samples: If set to True, all hard examples are kept in the selected subset of points.
|
||||
:param seed: Random seed used in shuffling data indices.
|
||||
"""
|
||||
random_state = np.random.RandomState(seed=seed)
|
||||
|
||||
assert num_samples <= TOTAL_CIFAR10H_DATASET_SIZE
|
||||
if (not keep_hard_samples) or (num_samples == TOTAL_CIFAR10H_DATASET_SIZE):
|
||||
indices = random_state.permutation(true_label_counts.shape[0])
|
||||
return indices[:num_samples]
|
||||
|
||||
# Identify difficult samples and keep them in the dataset
|
||||
hard_sample_indices = get_ambiguous_sample_ids(true_label_counts)
|
||||
if hard_sample_indices.shape[0] > num_samples:
|
||||
logging.info(f"Total number of hard samples: {hard_sample_indices.shape[0]} and requested: {num_samples}")
|
||||
hard_sample_indices = hard_sample_indices[:num_samples]
|
||||
num_hard_samples = hard_sample_indices.shape[0]
|
||||
|
||||
# Sample the remaining indices randomly and aggregate
|
||||
remaining_indices = np.setdiff1d(range(TOTAL_CIFAR10H_DATASET_SIZE), hard_sample_indices)
|
||||
easy_sample_indices = random_state.choice(remaining_indices, num_samples - num_hard_samples, replace=False)
|
||||
indices = np.concatenate([hard_sample_indices, easy_sample_indices], axis=0)
|
||||
random_state.shuffle(indices)
|
||||
|
||||
# Assert that there are no repeated indices
|
||||
_, _counts = np.unique(indices, return_counts=True)
|
||||
assert not np.any(_counts > 1)
|
||||
|
||||
return indices
|
||||
|
|
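A short usage sketch for CIFAR10H (the transform is illustrative):

import torchvision.transforms as T
from InnerEyeDataQuality.datasets.cifar10h import CIFAR10H

dataset = CIFAR10H(root='data', transform=T.ToTensor(), num_samples=1000, noise_temperature=2.0)
img, label = dataset[0]
print(dataset.label_counts.shape)  # (1000, 10): multi-annotator label counts for the selected subset
print(len(dataset))  # 1000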
@ -0,0 +1,116 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pydicom as dicom
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
KAGGLE_TOTAL_SIZE = 26684
|
||||
|
||||
|
||||
class KaggleCXR(Dataset):
|
||||
def __init__(self,
|
||||
data_directory: str,
|
||||
use_training_split: bool,
|
||||
train_fraction: float = 0.8,
|
||||
seed: int = 1234,
|
||||
shuffle: bool = True,
|
||||
transform: Optional[Callable] = None,
|
||||
num_samples: int = None,
|
||||
return_index: bool = True) -> None:
|
||||
"""
|
||||
Class for the full Kaggle RSNA Pneumonia Detection Dataset.
|
||||
|
||||
:param data_directory: the directory containing all training images from the Challenge (stage 1) as well as the
|
||||
dataset.csv containing the kaggle and the original labels.
|
||||
:param use_training_split: whether to return the training or the validation split of the dataset.
|
||||
:param train_fraction: the proportion of samples to use for training
|
||||
:param seed: random seed to use for dataset creation
|
||||
:param shuffle: whether to shuffle the dataset prior to splitting between validation and training
|
||||
:param transform: a preprocessing function that takes a PIL image as input and returns a tensor
|
||||
:param num_samples: number of samples to return (has to be smaller than the dataset split)
|
||||
"""
|
||||
|
||||
self.data_directory = Path(data_directory)
|
||||
if not self.data_directory.exists():
|
||||
logging.error(
|
||||
f"The data directory {self.data_directory} does not exist. Make sure to download to Kaggle data "
|
||||
f"first.The kaggle dataset can "
|
||||
"be acceded via the Kaggle CLI kaggle competitions download -c rsna-pneumonia-detection-challenge or "
|
||||
"on the main page of the challenge "
|
||||
"https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data?select=stage_2_train_images")
|
||||
|
||||
self.train = use_training_split
|
||||
self.train_fraction = train_fraction
|
||||
self.seed = seed
|
||||
self.random_state = np.random.RandomState(seed)
|
||||
full_dataset = pd.read_csv(self.data_directory / "dataset.csv")
|
||||
self.dataset_dataframe = full_dataset
|
||||
self.transforms = transform
|
||||
|
||||
targets = self.dataset_dataframe.label.values.astype(np.int64)
|
||||
subjects_ids = self.dataset_dataframe.subject.values
|
||||
|
||||
self.num_classes = 2
|
||||
self.num_datapoints = len(self.dataset_dataframe)
|
||||
all_indices = np.arange(len(self.dataset_dataframe))
|
||||
|
||||
# ------------- Split the data into training and validation sets ------------- #
|
||||
num_samples_set1 = int(self.num_datapoints * self.train_fraction)
|
||||
sampled_indices = self.random_state.permutation(all_indices) \
|
||||
if shuffle else all_indices
|
||||
train_indices = sampled_indices[:num_samples_set1]
|
||||
val_indices = sampled_indices[num_samples_set1:]
|
||||
self.indices = train_indices if use_training_split else val_indices
|
||||
|
||||
# ------------- Select subset of current split ------------- #
|
||||
if num_samples is not None:
|
||||
assert 0 < num_samples <= len(self.indices)
|
||||
self.indices = self.indices[:num_samples]
|
||||
|
||||
self.subject_ids = subjects_ids[self.indices]
|
||||
|
||||
self.targets = targets[self.indices].reshape(-1)
|
||||
|
||||
dataset_type = "TRAIN" if use_training_split else "VAL"
|
||||
logging.info(f"Proportion of positive labels - {dataset_type}: {np.mean(self.targets)}")
|
||||
logging.info(f"Number samples - {dataset_type}: {self.targets.shape[0]}")
|
||||
self.return_index = return_index
|
||||
self.weight = np.mean(self.targets)
|
||||
logging.info(f"Weight negative {self.weight:.2f} - weight positive {(1 - self.weight):.2f}")
|
||||
|
||||
def __getitem__(self, index: int) -> Union[Tuple[int, PIL.Image.Image, int], Tuple[PIL.Image.Image, int]]:
|
||||
"""
|
||||
|
||||
:param index: The index of the sample to be fetched
|
||||
:return: The image and label tensors
|
||||
"""
|
||||
subject_id = self.subject_ids[index]
|
||||
filename = self.data_directory / f"{subject_id}.dcm"
|
||||
target = self.targets[index]
|
||||
scan_image = dicom.dcmread(filename).pixel_array
|
||||
scan_image = Image.fromarray(scan_image)
|
||||
if self.transforms is not None:
|
||||
scan_image = self.transforms(scan_image)
|
||||
if self.return_index:
|
||||
return index, scan_image, int(target)
|
||||
return scan_image, int(target)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
|
||||
:return: The size of the dataset
|
||||
"""
|
||||
return len(self.indices)
|
||||
|
||||
def get_label_names(self) -> List[str]:
|
||||
return ["Normal", "Opacity"]
|
||||
|
|
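A hedged sketch of wrapping the class above in a PyTorch DataLoader (the data directory is a placeholder and must contain the stage-1 DICOM images together with dataset.csv):

from torch.utils.data import DataLoader
from torchvision import transforms

dataset = KaggleCXR(data_directory='<path-to-kaggle-data>', use_training_split=True,
                    transform=transforms.Compose([transforms.Resize(256), transforms.ToTensor()]))
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
indices, images, labels = next(iter(loader))  # return_index=True by default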
@ -0,0 +1,90 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import warnings

import numpy as np
|
||||
|
||||
class LabelDistribution(object):
|
||||
"""
|
||||
LabelDistribution class handles sampling from a label distribution with reproducible behavior given a seed
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
seed: int,
|
||||
label_counts: np.ndarray,
|
||||
temperature: float = 1.0,
|
||||
offset: float = 0.0) -> None:
|
||||
"""
|
||||
:param seed: The random seed used to ensure reproducible behaviour
|
||||
:param label_counts: An array of shape (num_samples, num_classes) where each entry represents the number of
|
||||
labels available for each sample and class
|
||||
:param temperature: A temperature value that will be used to temperature-scale the distribution. The default is
|
||||
1.0, which is equivalent to no scaling; temperature must be greater than 0.0. Values between 0 and 1 result
|
||||
in a sharper distribution and values greater than 1 in a more uniform distribution over classes.
|
||||
:param offset: Offset parameter to control the noise rate in sampling initial labels.
|
||||
All classes are assigned a uniform fixed offset amount.
|
||||
"""
|
||||
assert label_counts.dtype == np.int64
|
||||
assert label_counts.ndim == 2
|
||||
self.num_classes = label_counts.shape[1]
|
||||
self.num_samples = label_counts.shape[0]
|
||||
self.seed = seed
|
||||
self.random_state = np.random.RandomState(seed)
|
||||
self.label_counts = label_counts
|
||||
self.temperature = temperature
|
||||
|
||||
# make the distribution
|
||||
self.distribution = label_counts / np.sum(label_counts, axis=1, keepdims=True)
|
||||
assert np.isfinite(self.distribution).all()
|
||||
assert self.temperature > 0
|
||||
|
||||
# scale distribution based on temperature and offset
|
||||
_d = np.power(self.distribution, 1. / temperature)
|
||||
_d = _d / np.sum(_d, axis=1, keepdims=True)
|
||||
self.distribution_temp_scaled = self.add_noise_to_distribution(offset, _d, 'asym') if offset > 0.0 else _d
|
||||
|
||||
# check if there are multiple labels per data point
|
||||
self.is_multi_label_per_sample = np.all(np.sum(label_counts, axis=1) > 1.0)
|
||||
|
||||
def sample_initial_labels_for_all(self) -> np.ndarray:
|
||||
"""
|
||||
Sample one label for each sample in the dataset according to its label distribution
|
||||
:return: An array of sampled labels, one per sample in the dataset
|
||||
"""
|
||||
if not self.is_multi_label_per_sample:
|
||||
RuntimeWarning("Sampling labels from one-hot encoded distribution - Multi labels are not available")
|
||||
|
||||
sampling_fn = lambda p: self.random_state.choice(self.num_classes, 1, p=p)
|
||||
|
||||
return np.squeeze(np.apply_along_axis(sampling_fn, arr=self.distribution_temp_scaled, axis=1))
|
||||
|
||||
def sample(self, sample_idx: int) -> int:
|
||||
"""
|
||||
Sample one label for a given sample index
|
||||
:param sample_idx: The sample index for which the label will be sampled
|
||||
:return: The sampled class label for the given sample
|
||||
"""
|
||||
return self.random_state.choice(self.num_classes, 1, p=self.distribution[sample_idx])[0]
|
||||
|
||||
def add_noise_to_distribution(self, offset: float, distribution: np.ndarray, noise_model_type: str) -> np.ndarray:
|
||||
from InnerEyeDataQuality.datasets.label_noise_model import get_cifar10_asym_noise_model
|
||||
from InnerEyeDataQuality.datasets.label_noise_model import get_cifar10_sym_noise_model
|
||||
|
||||
# Create noise model
|
||||
if noise_model_type == 'sym':
|
||||
noise_model = get_cifar10_sym_noise_model(eta=1.0)
|
||||
elif noise_model_type == 'asym':
|
||||
noise_model = get_cifar10_asym_noise_model(eta=1.0)
|
||||
else:
|
||||
raise ValueError("Unknown noise model type")
|
||||
np.fill_diagonal(noise_model, 0.0)
|
||||
noise_model *= offset
|
||||
|
||||
# Add this noise on top of every sample in the dataset
|
||||
for _ii in range(self.num_samples):
|
||||
true_label = np.argmax(self.label_counts[_ii])
|
||||
distribution[_ii] += noise_model[true_label]
|
||||
|
||||
# Normalise the distribution
|
||||
return distribution / np.sum(distribution, axis=1, keepdims=True)
|
|
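To make the temperature scaling concrete, a small worked sketch: with label counts [8, 2] the base distribution is [0.8, 0.2]; with temperature 2.0 it becomes proportional to [0.8**0.5, 0.2**0.5] and renormalises to roughly [0.67, 0.33], i.e. flatter than the original, as described in the docstring above.

import numpy as np

counts = np.array([[8, 2]], dtype=np.int64)
dist = LabelDistribution(seed=0, label_counts=counts, temperature=2.0)
print(dist.distribution)              # [[0.8 0.2]]
print(dist.distribution_temp_scaled)  # approximately [[0.667 0.333]]
print(dist.sample(0))                 # samples 0 or 1 from the unscaled distribution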
@ -0,0 +1,86 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import numpy as np
|
||||
|
||||
from default_paths import CIFAR10_ROOT_DIR
|
||||
from InnerEyeDataQuality.datasets.cifar10h import CIFAR10H
|
||||
from InnerEyeDataQuality.selection.simulation_statistics import get_ambiguous_sample_ids
|
||||
|
||||
from sklearn.metrics import confusion_matrix
|
||||
|
||||
|
||||
def get_cifar10h_confusion_matrix(temperature: float = 1.0, only_difficult_cases: bool = False) -> np.ndarray:
|
||||
"""
|
||||
Generates a class confusion matrix based on the label distribution in CIFAR10H.
|
||||
"""
|
||||
cifar10h_labels = CIFAR10H.download_cifar10h_labels(str(CIFAR10_ROOT_DIR))
|
||||
|
||||
if only_difficult_cases:
|
||||
ambiguous_sample_ids = get_ambiguous_sample_ids(cifar10h_labels)
|
||||
cifar10h_labels = cifar10h_labels[ambiguous_sample_ids, :]
|
||||
|
||||
# Temperature scale the original distribution
|
||||
if temperature > 1.0:
|
||||
orig_distribution = cifar10h_labels / np.sum(cifar10h_labels, axis=1, keepdims=True)
|
||||
_d = np.power(orig_distribution, 1. / temperature)
|
||||
scaled_distribution = _d / np.sum(_d, axis=1, keepdims=True)
|
||||
sample_counts = (scaled_distribution * np.sum(cifar10h_labels, axis=1, keepdims=True)).astype(np.int64)
|
||||
else:
|
||||
sample_counts = cifar10h_labels
|
||||
|
||||
y_pred, y_true = list(), list()
|
||||
for image_index in range(sample_counts.shape[0]):
|
||||
image_label_counts = sample_counts[image_index]
|
||||
for _iter, _el in enumerate(image_label_counts.tolist()):
|
||||
y_pred.extend([_iter] * _el)
|
||||
y_true.extend([np.argmax(image_label_counts)] * np.sum(image_label_counts))
|
||||
cm = confusion_matrix(y_true, y_pred, normalize="true")
|
||||
|
||||
return cm
|
||||
|
||||
|
||||
def get_cifar10_asym_noise_model(eta: float = 0.3) -> np.ndarray:
|
||||
"""
|
||||
CLASS-DEPENDENT ASYMMETRIC LABEL NOISE
|
||||
https://proceedings.neurips.cc/paper/2018/file/f2925f97bc13ad2852a7a551802feea0-Paper.pdf
|
||||
TRUCK -> AUTOMOBILE, BIRD -> AIRPLANE, DEER -> HORSE, CAT -> DOG, and DOG -> CAT
|
||||
|
||||
:param eta: The likelihood of true label switching from one of the specified classes to nearest class.
|
||||
In other words, likelihood of introducing a class-dependent label noise
|
||||
"""
|
||||
|
||||
# Generate a noise transition matrix.
|
||||
assert (0.0 <= eta) and (eta <= 1.0)
|
||||
|
||||
eps = 1e-12
|
||||
num_classes = 10
|
||||
conf_mat = np.eye(N=num_classes)
|
||||
indices = [[2, 0], [9, 1], [5, 3], [3, 5], [4, 7]]
|
||||
for ind in indices:
|
||||
conf_mat[ind[0], ind[1]] = eta / (1.0 - eta + eps)
|
||||
return conf_mat / np.sum(conf_mat, axis=1, keepdims=True)
|
||||
|
||||
|
||||
def get_cifar10_sym_noise_model(eta: float = 0.3) -> np.ndarray:
|
||||
"""
|
||||
SYMMETRIC LABEL NOISE
|
||||
:param eta: The likelihood of true label switching from true class to rest of the classes.
|
||||
"""
|
||||
# Generate a noise transition matrix.
|
||||
assert (0.0 <= eta) and (eta <= 1.0)
|
||||
assert isinstance(eta, float)
|
||||
|
||||
num_classes = 10
|
||||
conf_mat = np.eye(N=num_classes)
|
||||
for ind in range(num_classes):
|
||||
conf_mat[ind, ind] -= eta
|
||||
other_classes = np.setdiff1d(range(num_classes), ind)
|
||||
for o_c in other_classes:
|
||||
conf_mat[ind, o_c] += eta / other_classes.size
|
||||
|
||||
assert np.all(np.abs(np.sum(conf_mat, axis=1) - 1.0) < 1e-9)
|
||||
|
||||
return conf_mat
|
|
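As a quick sanity check on these transition matrices (a sketch, not part of the commit): with eta=0.4 the symmetric model keeps each true label with probability 0.6, so the expected noise rate is 0.4 for every class, whereas the asymmetric model only corrupts the five listed classes, giving an expected rate of 0.2 under a uniform class prior.

import numpy as np

sym = get_cifar10_sym_noise_model(eta=0.4)
asym = get_cifar10_asym_noise_model(eta=0.4)
print(1.0 - np.mean(np.diag(sym)))   # 0.4: every class corrupted at the same rate
print(1.0 - np.mean(np.diag(asym)))  # 0.2: only 5 of the 10 classes are corrupted (each at rate 0.4)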
@ -0,0 +1,106 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Tuple, Dict, Union
|
||||
|
||||
import PIL
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
NIH_TOTAL_SIZE = 112120
|
||||
|
||||
|
||||
class NIHCXR(Dataset):
|
||||
def __init__(self,
|
||||
data_directory: str,
|
||||
use_training_split: bool,
|
||||
seed: int = 1234,
|
||||
shuffle: bool = True,
|
||||
transform: Optional[Callable] = None,
|
||||
num_samples: int = None,
|
||||
return_index: bool = True) -> None:
|
||||
"""
|
||||
Class for the full NIH ChestXray Dataset (112k images)
|
||||
|
||||
:param data_directory: the directory containing all training images from the dataset as well as the
|
||||
Data_Entry_2017.csv file containing the dataset labels.
|
||||
:param use_training_split: whether to return the training or the test split of the dataset.
|
||||
:param seed: random seed to use for dataset creation
|
||||
:param shuffle: whether to shuffle the dataset prior to splitting between validation and training
|
||||
:param transform: a preprocessing function that takes a PIL image as input and returns a tensor
|
||||
:param num_samples: number of samples to return (has to be smaller than the dataset split)
|
||||
"""
|
||||
self.data_directory = Path(data_directory)
|
||||
if not self.data_directory.exists():
|
||||
logging.error(
|
||||
f"The data directory {self.data_directory} does not exist. Make sure to download the NIH data "
|
||||
f"first.The dataset can on the main page"
|
||||
"https://www.kaggle.com/nih-chest-xrays/data. Make sure all images are placed directly under the "
|
||||
"data_directory folder. Make sure you downloaded the Data_Entry_2017.csv file to this directory as"
|
||||
"well.")
|
||||
|
||||
self.train = use_training_split
|
||||
self.seed = seed
|
||||
self.random_state = np.random.RandomState(seed)
|
||||
self.dataset_dataframe = pd.read_csv(self.data_directory / "Data_Entry_2017.csv")
|
||||
self.dataset_dataframe["pneumonia_like"] = self.dataset_dataframe["Finding Labels"].apply(
|
||||
lambda x: x.split("|")).apply(lambda x: "pneumonia" in x.lower()
|
||||
or "infiltration" in x.lower()
|
||||
or "consolidation" in x.lower())
|
||||
self.transforms = transform
|
||||
|
||||
orig_labels = self.dataset_dataframe.pneumonia_like.values.astype(np.int64)
|
||||
subjects_ids = self.dataset_dataframe["Image Index"].values
|
||||
is_train_ids = self.dataset_dataframe["train"].values
|
||||
self.num_classes = 2
|
||||
self.indices = np.where(is_train_ids)[0] if use_training_split else np.where(~is_train_ids)[0]
|
||||
self.indices = self.random_state.permutation(self.indices) \
|
||||
if shuffle else self.indices
|
||||
# ------------- Select subset of current split ------------- #
|
||||
if num_samples is not None:
|
||||
assert 0 < num_samples <= len(self.indices)
|
||||
self.indices = self.indices[:num_samples]
|
||||
|
||||
self.subject_ids = subjects_ids[self.indices]
|
||||
self.orig_labels = orig_labels[self.indices].reshape(-1)
|
||||
self.targets = self.orig_labels
|
||||
|
||||
# Identify case ids for ambiguous and clear label noise cases
|
||||
self.ambiguity_metric_args: Dict = dict()
|
||||
|
||||
dataset_type = "TRAIN" if use_training_split else "VAL"
|
||||
logging.info(f"Proportion of positive labels - {dataset_type}: {np.mean(self.targets)}")
|
||||
logging.info(f"Number samples - {dataset_type}: {self.targets.shape[0]}")
|
||||
self.return_index = return_index
|
||||
|
||||
def __getitem__(self, index: int) -> Union[Tuple[int, PIL.Image.Image, int], Tuple[PIL.Image.Image, int]]:
|
||||
"""
|
||||
|
||||
:param index: The index of the sample to be fetched
|
||||
:return: The image and label tensors
|
||||
"""
|
||||
subject_id = self.subject_ids[index]
|
||||
filename = self.data_directory / f"{subject_id}"
|
||||
target = self.targets[index]
|
||||
scan_image = Image.open(filename).convert("L")
|
||||
if self.transforms is not None:
|
||||
scan_image = self.transforms(scan_image)
|
||||
if self.return_index:
|
||||
return index, scan_image, int(target)
|
||||
return scan_image, int(target)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
|
||||
:return: The size of the dataset
|
||||
"""
|
||||
return len(self.indices)
|
||||
|
||||
def get_label_names(self) -> List[str]:
|
||||
return ["NotPneunomiaLike", "PneunomiaLike"]
|
|
@ -0,0 +1,112 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
def create_nih_dataframe(mapping_file_path: Path) -> pd.DataFrame:
|
||||
"""
|
||||
This function loads the json file mapping NIH ids to Kaggle images.
|
||||
Loads the original NIH label (multiple labels for each image).
|
||||
Then it creates the grouping by NIH categories (pneumonia, pneumonia like,
|
||||
other disease, no finding).
|
||||
:param mapping_file_path: path to the json mapping from NIH to Kaggle dataset (on the
|
||||
RSNA webpage)
|
||||
:return: dataframe with original NIH labels for each patient in the Kaggle dataset.
|
||||
"""
|
||||
with open(mapping_file_path) as f:
|
||||
list_subjects = json.load(f)
|
||||
orig_dataset = pd.DataFrame(columns=["subject", "orig_label"])
|
||||
orig_dataset["subject"] = [l["subset_img_id"] for l in list_subjects] # noqa: E741
|
||||
orig_labels = [str(l['orig_labels']).lower() for l in list_subjects] # noqa: E741
|
||||
orig_dataset["nih_pneumonia"] = ["pneumonia" in l for l in orig_labels] # noqa: E741
|
||||
orig_dataset["nih_pneumonia_like"] = [(("infiltration" in l or "consolidation" in l) and ("pneumonia" not in l)) for
|
||||
l in orig_labels] # noqa: E741
|
||||
orig_dataset["no_finding"] = ["no finding" in str(l).lower() for l in orig_labels] # noqa: E741
|
||||
orig_dataset["orig_label"] = orig_labels
|
||||
orig_dataset["orig_label"].apply(lambda x: sorted(x))
|
||||
orig_dataset[
|
||||
"nih_other_disease"] = ~orig_dataset.nih_pneumonia_like & ~orig_dataset.nih_pneumonia & ~orig_dataset.no_finding
|
||||
orig_dataset[
|
||||
"nih_category"] = 1 * orig_dataset.nih_pneumonia + 2 * orig_dataset.no_finding + 3 * \
|
||||
orig_dataset.nih_other_disease
|
||||
orig_dataset["StudyInstanceUID"] = [l["StudyInstanceUID"] for l in list_subjects] # noqa: E741
|
||||
orig_dataset.nih_category = orig_dataset.nih_category.apply(lambda x: ["Consolidation/Infiltration", "Pneumonia",
|
||||
"No finding", "Other disease"][x])
|
||||
return orig_dataset
|
||||
|
||||
|
||||
def process_detailed_probs_dataset(detailed_probs_path: Path) -> pd.DataFrame:
|
||||
"""
|
||||
This function loads the csv file with the detailed information for each bounding boxes as annotated
|
||||
by the readers during the adjudication for the preparation of the challenge. It maps low, medium and high
|
||||
probability labels to a numerical scale from 1 to 3. It computes the minimum, maximum and average confidence for each
|
||||
patient for which at least one bounding box was present.
|
||||
:param detailed_probs_path: path to detailed_probs csv file released in Kaggle challenge.
|
||||
:return: dataframe with metrics for confidence in bounding boxes by patient.
|
||||
"""
|
||||
conversion_map = {"Lung Opacity (Low Prob)": 1, "Lung Opacity (Med Prob)": 2, "Lung Opacity (High Prob)": 3}
|
||||
detailed_probs_dataset = pd.read_csv(detailed_probs_path)
|
||||
detailed_probs_dataset["ClassProb"] = detailed_probs_dataset["labelName"].apply(lambda x: conversion_map[x])
|
||||
process_details = detailed_probs_dataset.groupby("StudyInstanceUID")["ClassProb"].agg(
|
||||
[np.mean, np.min, np.max, np.count_nonzero, list])
|
||||
process_details.rename(columns={"mean": "avg_conf_score", "amin": "min_conf_score", "amax": "max_conf_score"},
|
||||
inplace=True)
|
||||
return process_details
|
||||
|
||||
|
||||
def create_mapping_dataset_nih(mapping_file_path: Path,
|
||||
kaggle_dataset_path: Path,
|
||||
detailed_class_info_path: Path,
|
||||
detailed_probs_path: Path) -> pd.DataFrame:
|
||||
"""
|
||||
Creates the final chest x-ray dataset combining labels from NIH, kaggle and detailed information about kaggle
|
||||
labels from the detailed_class_info and detailed_probs csv file released during the challenge.
|
||||
:param mapping_file_path: path to the json file mapping NIH ids to Kaggle images.
|
||||
:param kaggle_dataset_path: path to the Kaggle stage_2_train_labels.csv file.
|
||||
:param detailed_class_info_path: path to the stage_2_detailed_class_info.csv file.
|
||||
:param detailed_probs_path: path to the detailed probabilities csv file released in the Kaggle challenge.
|
||||
:return: the combined, detailed dataset as a dataframe.
|
||||
"""
|
||||
orig_dataset = create_nih_dataframe(mapping_file_path)
|
||||
kaggle_dataset = pd.read_csv(kaggle_dataset_path)
|
||||
difficulty = pd.read_csv(detailed_class_info_path).drop_duplicates()
|
||||
detailed_probs_dataset = process_detailed_probs_dataset(detailed_probs_path)
|
||||
# Merge NIH info with Kaggle dataset
|
||||
merged = pd.merge(orig_dataset, kaggle_dataset)
|
||||
merged.rename(columns={"label": "label_kaggle"}, inplace=True)
|
||||
# Define the binary label from the original NIH label; for consolidation/infiltration the
|
||||
# mapping is not clear, so use the Kaggle label instead.
|
||||
merged.loc[merged.nih_pneumonia, "binary_nih_initial_label"] = True
|
||||
merged.loc[merged.no_finding | merged.nih_other_disease, "binary_nih_initial_label"] = False
|
||||
merged.loc[merged.nih_pneumonia_like, "binary_nih_initial_label"] = merged.loc[
|
||||
merged.nih_pneumonia_like, "label_kaggle"]
|
||||
# Add subclass information from Kaggle challenge to define ambiguous cases
|
||||
merged = pd.merge(merged, difficulty, left_on="subject", right_on="patientId")
|
||||
merged["not_normal"] = (merged["class"] == "No Lung Opacity / Not Normal")
|
||||
# Add difficulty information from Kaggle based on presence of bounding boxes
|
||||
merged = pd.merge(merged, detailed_probs_dataset, on="StudyInstanceUID", how="left")
|
||||
merged.drop(columns=["patientId"], inplace=True)
|
||||
merged.fillna(-1, inplace=True)
|
||||
# Ambiguous if there were only low-probability boxes and the adjudicated label is true
|
||||
merged.loc[merged.label_kaggle, ["ambiguous"]] = (merged.loc[merged.label_kaggle].min_conf_score == 1) & (
|
||||
merged.loc[merged.label_kaggle].max_conf_score == 1)
|
||||
# Ambiguous if there were some bounding boxes but the adjudicated label is false
|
||||
merged.loc[~merged.label_kaggle, ["ambiguous"]] = merged.loc[~merged.label_kaggle].min_conf_score > -1
|
||||
return merged
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from default_paths import INNEREYE_DQ_DIR
|
||||
current_dir = Path(__file__).parent
|
||||
mapping_file = Path(__file__).parent / "pneumonia-challenge-dataset-mappings_2018.json"
|
||||
kaggle_dataset_path = current_dir / "stage_2_train_labels.csv"
|
||||
detailed_class_info = current_dir / "stage_2_detailed_class_info.csv"
|
||||
detailed_probs = current_dir / "RSNA_pneumonia_all_probs.csv"
|
||||
dataset = create_mapping_dataset_nih(mapping_file, kaggle_dataset_path, detailed_class_info, detailed_probs)
|
||||
dataset.to_csv(INNEREYE_DQ_DIR / "datasets" / "noisy_chestxray_dataset.csv", index=False)
|
|
@ -0,0 +1,183 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pydicom as dicom
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from InnerEyeDataQuality.datasets.label_distribution import LabelDistribution
|
||||
from InnerEyeDataQuality.selection.simulation_statistics import SimulationStats
|
||||
from InnerEyeDataQuality.utils.generic import convert_labels_to_one_hot
|
||||
from InnerEyeDataQuality.evaluation.metrics import compute_label_entropy
|
||||
|
||||
|
||||
class NoisyKaggleSubsetCXR(Dataset):
|
||||
def __init__(self, data_directory: str,
|
||||
use_training_split: bool,
|
||||
consolidation_noise_rate: float,
|
||||
train_fraction: float = 0.5,
|
||||
seed: int = 1234,
|
||||
shuffle: bool = True,
|
||||
transform: Optional[Callable] = None,
|
||||
num_samples: Optional[int] = None,
|
||||
use_noisy_fixed_labels: bool = True) -> None:
|
||||
"""
|
||||
Class for the noisy Kaggle RSNA Pneumonia Detection Dataset. This dataset uses the original labels from
|
||||
RSNA/NIH as the noisy initial
|
||||
labels, while the adjudicated Kaggle labels are used as the clean labels.
|
||||
|
||||
:param data_directory: the directory containing all training images from the Challenge (stage 1) as well as the
|
||||
dataset.csv containing the kaggle and the original labels.
|
||||
:param use_training_split: whether to return the training or the validation split of the dataset.
|
||||
:param train_fraction: the proportion of samples to use for training
|
||||
:param seed: random seed to use for dataset creation
|
||||
:param shuffle: whether to shuffle the dataset prior to splitting between validation and training
|
||||
:param transform: a preprocessing function that takes a PIL image as input and returns a tensor
|
||||
:param num_samples: number of samples to return (has to be smaller than the dataset split)
|
||||
:param use_noisy_fixed_labels: if True use the original labels as the initial labels else use the clean labels.
|
||||
:param consolidation_noise_rate: proportion of noisy samples among consolidation/infiltration NIH category.
|
||||
"""
|
||||
dataset_type = "TRAIN" if use_training_split else "VAL"
|
||||
self.data_directory = Path(data_directory)
|
||||
if not self.data_directory.exists():
|
||||
raise RuntimeError(
|
||||
f"The data directory {self.data_directory} does not exist. Make sure to download to Kaggle data "
|
||||
f"first.The kaggle dataset can "
|
||||
"be acceded via the Kaggle CLI kaggle competitions download -c rsna-pneumonia-detection-challenge or "
|
||||
"on the main page of the challenge "
|
||||
"https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data?select=stage_2_train_images")
|
||||
|
||||
path_to_noisy_csv = Path(__file__).parent / "noisy_chestxray_dataset.csv"
|
||||
if not path_to_noisy_csv.exists():
|
||||
raise RuntimeError(f"The noisy dataset csv can not be found in {path_to_noisy_csv}, make sure to run "
|
||||
"create_noisy_chestxray_dataset.py first. See readme for more detailed instructions on the pre-requisite"
|
||||
" for running the noisy Chest Xray benchmark.")
|
||||
|
||||
self.train = use_training_split
|
||||
self.train_fraction = train_fraction
|
||||
self.seed = seed
|
||||
self.random_state = np.random.RandomState(seed)
|
||||
self.dataset_dataframe = pd.read_csv(path_to_noisy_csv)
|
||||
self.transforms = transform
|
||||
|
||||
self.dataset_dataframe["initial_label"] = self.dataset_dataframe.binary_nih_initial_label
|
||||
|
||||
# Random uniform noise among consolidation
|
||||
pneumonia_like_subj = self.dataset_dataframe.loc[self.dataset_dataframe.nih_pneumonia_like, "subject"].values
|
||||
selected = self.random_state.choice(pneumonia_like_subj,
|
||||
replace=False,
|
||||
size=int(len(pneumonia_like_subj) * consolidation_noise_rate))
|
||||
self.dataset_dataframe.loc[self.dataset_dataframe.subject.isin(selected), "initial_label"] = \
|
||||
~self.dataset_dataframe.loc[self.dataset_dataframe.subject.isin(selected), "initial_label"]
|
||||
|
||||
self.dataset_dataframe["is_noisy"] = self.dataset_dataframe.label_kaggle != self.dataset_dataframe.initial_label
|
||||
|
||||
initial_labels = self.dataset_dataframe.initial_label.values.astype(np.int64).reshape(-1, 1)
|
||||
kaggle_labels = self.dataset_dataframe.label_kaggle.values.astype(np.int64).reshape(-1, 1)
|
||||
subjects_ids = self.dataset_dataframe.subject.values
|
||||
is_ambiguous = self.dataset_dataframe.ambiguous.values
|
||||
orig_label = self.dataset_dataframe.orig_label.values
|
||||
nih_category = self.dataset_dataframe.nih_category.values
|
||||
|
||||
# Convert clean labels to one-hot to populate label counts
|
||||
# i.e for easy cases assume the true distribution is 100% ground truth
|
||||
kaggle_label_counts = convert_labels_to_one_hot(kaggle_labels, n_classes=2)
|
||||
# For ambiguous cases: [0, 1] -> [1, 2] and [1, 0] -> [2, 1]
|
||||
kaggle_label_counts[is_ambiguous, :] = kaggle_label_counts[is_ambiguous, :] * 2 + 1
|
||||
_, self.num_classes = kaggle_label_counts.shape
|
||||
assert self.num_classes == 2
|
||||
|
||||
# ------------- Split the data into training and validation sets ------------- #
|
||||
self.num_datapoints = len(self.dataset_dataframe)
|
||||
all_indices = np.arange(self.num_datapoints)
|
||||
num_samples_set1 = int(self.num_datapoints * self.train_fraction)
|
||||
all_indices = self.random_state.permutation(all_indices) \
|
||||
if shuffle else all_indices
|
||||
train_indices = all_indices[:num_samples_set1]
|
||||
val_indices = all_indices[num_samples_set1:]
|
||||
self.indices = train_indices if use_training_split else val_indices
|
||||
|
||||
# ------------- Select subset of current split ------------- #
|
||||
# If n_samples is set to restrict dataset i.e. for data_curation
|
||||
num_samples = self.num_datapoints if num_samples is None else num_samples
|
||||
if num_samples < self.num_datapoints:
|
||||
assert 0 < num_samples <= len(self.indices)
|
||||
self.indices = self.indices[:num_samples]
|
||||
|
||||
# ------------ Finalize dataset --------------- #
|
||||
self.subject_ids = subjects_ids[self.indices]
|
||||
|
||||
# Label distribution is constructed from the true labels
|
||||
self.label_counts = kaggle_label_counts[self.indices]
|
||||
self.label_distribution = LabelDistribution(seed, self.label_counts)
|
||||
|
||||
self.initial_labels = initial_labels[self.indices].reshape(-1)
|
||||
self.kaggle_labels = kaggle_labels[self.indices].reshape(-1)
|
||||
self.targets = self.initial_labels if use_noisy_fixed_labels else self.kaggle_labels
|
||||
self.orig_labels = orig_label[self.indices]
|
||||
self.is_ambiguous = is_ambiguous[self.indices]
|
||||
self.nih_category = nih_category[self.indices]
|
||||
|
||||
# Identify case ids for ambiguous and clear label noise cases
|
||||
label_stats = SimulationStats(name="NoisyChestXray", true_label_counts=self.label_counts,
|
||||
initial_labels=convert_labels_to_one_hot(self.targets, self.num_classes))
|
||||
self.clear_mislabeled_cases = label_stats.mislabelled_not_ambiguous_sample_ids[0]
|
||||
self.ambiguous_mislabelled_cases = label_stats.mislabelled_ambiguous_sample_ids[0]
|
||||
self.true_label_entropy = compute_label_entropy(label_counts=self.label_counts)
|
||||
self.ambiguity_metric_args = {"ambiguous_mislabelled_ids": self.ambiguous_mislabelled_cases,
|
||||
"clear_mislabelled_ids": self.clear_mislabeled_cases,
|
||||
"true_label_entropy": self.true_label_entropy}
|
||||
self.num_samples = self.targets.shape[0]
|
||||
logging.info(self.num_samples)
|
||||
logging.info(len(self.targets))
|
||||
logging.info(len(self.indices))
|
||||
logging.info(f"Proportion of positive clean labels - {dataset_type}: {np.mean(self.kaggle_labels)}")
|
||||
logging.info(f"Proportion of positive noisy labels - {dataset_type}: {np.mean(self.targets)}")
|
||||
logging.info(
|
||||
f"Total noise rate on the {dataset_type} dataset: {np.mean(self.kaggle_labels != self.targets)} \n")
|
||||
selected_df = self.dataset_dataframe.loc[self.dataset_dataframe.subject.isin(self.subject_ids)]
|
||||
noisy_df = selected_df.loc[selected_df.is_noisy]
|
||||
noisy_df["nih_noise"] = ~noisy_df.nih_pneumonia_like
|
||||
logging.info(f"\n{pd.crosstab(noisy_df.nih_noise, noisy_df.ambiguous).to_string()}")
|
||||
# self.weight = np.mean(self.kaggle_labels)
|
||||
# logging.info(f"Weight negative {self.weight:.2f} - weight positive {(1 - self.weight):.2f}")
|
||||
self.png_files = (self.data_directory / f"{self.subject_ids[0]}.png").exists()
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[int, PIL.Image.Image, int]:
|
||||
"""
|
||||
|
||||
:param index: The index of the sample to be fetched
|
||||
:return: The sample index, the (transformed) image and the integer label
|
||||
"""
|
||||
subject_id = self.subject_ids[index]
|
||||
target = self.targets[index]
|
||||
if self.png_files:
|
||||
filename = self.data_directory / f"{subject_id}.png"
|
||||
scan_image = Image.open(filename)
|
||||
else:
|
||||
filename = self.data_directory / f"{subject_id}.dcm"
|
||||
scan_image = dicom.dcmread(filename).pixel_array
|
||||
scan_image = Image.fromarray(scan_image)
|
||||
if self.transforms is not None:
|
||||
scan_image = self.transforms(scan_image)
|
||||
if len(scan_image.shape) == 2:
|
||||
scan_image = scan_image.unsqueeze(dim=0)
|
||||
return index, scan_image, int(target)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
|
||||
:return: The size of the dataset
|
||||
"""
|
||||
return len(self.subject_ids)
|
||||
|
||||
def get_label_names(self) -> List[str]:
|
||||
return ["Normal", "Opacity"]
|
|
@ -0,0 +1,45 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Any, List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from math import inf
|
||||
from scipy import stats
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def get_instance_noise_model(n: float, dataset: Any, labels: Union[List, torch.Tensor], num_classes: int,
|
||||
feature_size: int, norm_std: float, seed: int) -> np.ndarray:
|
||||
"""
|
||||
:param n: average noise rate
:param dataset: dataset yielding (image, label) pairs, e.g. CIFAR10 (not the train_loader)
:param labels: labels (targets)
:param num_classes: number of classes
:param feature_size: size of the flattened input images (e.g. 28*28)
:param norm_std: standard deviation of the truncated normal used to sample per-sample flip rates (default 0.1)
:param seed: random seed
:return: array of shape (num_samples, num_classes) holding the per-sample label-flip probabilities
|
||||
"""
|
||||
if isinstance(labels, list):
|
||||
labels = torch.FloatTensor(labels)
|
||||
P = []
|
||||
random_state = np.random.RandomState(seed)
|
||||
flip_distribution = stats.truncnorm((0 - n) / norm_std, (1 - n) / norm_std, loc=n, scale=norm_std)
|
||||
flip_rate = flip_distribution.rvs(labels.shape[0], random_state=seed)
|
||||
|
||||
W = random_state.randn(num_classes, feature_size, num_classes)
|
||||
W = torch.FloatTensor(W)
|
||||
|
||||
for i, (x, y) in enumerate(dataset):
|
||||
# (1 x M) * (M x 10) = (1 x 10)
|
||||
A = x.view(1, -1).mm(W[y]).squeeze(0)
|
||||
A[y] = -inf
|
||||
A = flip_rate[i] * F.softmax(A, dim=0)
|
||||
A[y] += 1 - flip_rate[i]
|
||||
P.append(A)
|
||||
P = torch.stack(P, 0).numpy()
|
||||
|
||||
return P
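if __name__ == "__main__":
    # Usage sketch (added for illustration, not part of the original pipeline): build the
    # instance-dependent noise model for CIFAR-10 and sample one noisy label per image from
    # the per-sample transition matrix P returned above. The dataset root "data" and the
    # hyper-parameter values below are assumptions, not values mandated by the paper.
    from torchvision import datasets, transforms

    cifar10 = datasets.CIFAR10(root="data", train=True, download=True,
                               transform=transforms.ToTensor())
    P = get_instance_noise_model(n=0.2, dataset=cifar10, labels=cifar10.targets,
                                 num_classes=10, feature_size=3 * 32 * 32,
                                 norm_std=0.1, seed=42)
    rng = np.random.RandomState(42)
    noisy_labels = []
    for row in P:
        p = np.asarray(row, dtype=np.float64)
        noisy_labels.append(rng.choice(10, p=p / p.sum()))
    noisy_labels = np.array(noisy_labels)
    print("Empirical noise rate:", np.mean(noisy_labels != np.asarray(cifar10.targets)))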
|
|
@ -0,0 +1,5 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
|
||||
from InnerEyeDataQuality.configs.config_node import ConfigNode
|
||||
|
||||
|
||||
class Network(nn.Module):
|
||||
"""
|
||||
DenseNet121 as implemented in torchvision.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ConfigNode) -> None:
|
||||
super().__init__()
|
||||
self.densenet121 = torchvision.models.densenet121(pretrained=config.train.pretrained, progress=False)
|
||||
num_ftrs = self.densenet121.classifier.in_features
|
||||
self.densenet121.classifier = nn.Linear(num_ftrs, config.dataset.n_classes)
|
||||
self.projection = self.densenet121.classifier
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.densenet121(x)
|
||||
return x
|
|
@ -0,0 +1,26 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
|
||||
from InnerEyeDataQuality.configs.config_node import ConfigNode
|
||||
|
||||
|
||||
class Network(nn.Module):
|
||||
"""
|
||||
ResNet50 as implemented in torchvision.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ConfigNode) -> None:
|
||||
super().__init__()
|
||||
self.resnet = torchvision.models.resnet50(pretrained=config.train.pretrained, progress=False)
|
||||
num_ftrs = self.resnet.fc.in_features
|
||||
self.resnet.fc = nn.Linear(num_ftrs, config.dataset.n_classes)
|
||||
self.projection = self.resnet.fc
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.resnet(x)
|
||||
return x
|
|
@ -0,0 +1,117 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class EMA:
|
||||
"""Exponential moving average of model parameters.
|
||||
Args:
|
||||
model (torch.nn.Module): Model with parameters whose EMA will be kept.
|
||||
decay (float): Decay rate for exponential moving average.
|
||||
step_max (int): Maximum required number of steps to reach specified decay rate from zero.
|
||||
In the initial epochs, decay rate is kept small to keep the teacher model up-to-date.
|
||||
"""
|
||||
|
||||
def __init__(self, model: torch.nn.Module, decay: float = 0.99, step_max: int = 150) -> None:
|
||||
self.decay_max = decay
|
||||
self.step_max = step_max
|
||||
self.step_count = int(0)
|
||||
self.shadow = {}
|
||||
self.original: Dict[str, torch.Tensor] = {}
|
||||
self.model = model # reference to the student model
|
||||
|
||||
# Register model parameters
|
||||
for name, param in self.model.named_parameters():
|
||||
if param.requires_grad:
|
||||
self.shadow[name] = param.data.clone()
|
||||
|
||||
logging.info("Creating a teacher model (EMA)")
|
||||
|
||||
def update(self) -> None:
|
||||
"""
|
||||
Merges the current student model parameters into the stored exponential moving average.
|
||||
"""
|
||||
self.step_count += int(1)
|
||||
decay = self._get_decay_rate(self.step_count)
|
||||
for name, param in self.model.named_parameters():
|
||||
if param.requires_grad:
|
||||
assert name in self.shadow
|
||||
new_average = (1.0 - decay) * param.data + decay * self.shadow[name]
|
||||
self.shadow[name] = new_average.clone()
|
||||
|
||||
def _get_decay_rate(self, step: int) -> float:
|
||||
"""
|
||||
Returns the decay rate for the current step (cosine ramp-up from zero to decay_max over step_max steps).
|
||||
"""
|
||||
if step <= self.step_max:
|
||||
ratio = step / self.step_max
|
||||
return 0.5 * self.decay_max * (1 - np.cos(ratio * np.pi))
|
||||
else:
|
||||
return self.decay_max
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Runs inference through the student model architecture using the EMA (teacher) parameters.
|
||||
"""
|
||||
is_train = self.model.training
|
||||
self._assign()
|
||||
self.model.eval()
|
||||
self._switch_hooks(False)
|
||||
outputs = self.model.forward(inputs).detach()
|
||||
self._switch_hooks(True)
|
||||
self._restore()
|
||||
self.model.train() if is_train else self.model.eval()
|
||||
|
||||
return outputs
|
||||
|
||||
def _switch_hooks(self, bool_value: bool) -> None:
|
||||
for layer in self.model.children():
|
||||
if hasattr(layer, "use_hook"):
|
||||
layer.use_hook = bool_value # type: ignore
|
||||
|
||||
def _assign(self) -> None:
|
||||
"""Assign exponential moving average of parameter values to the
|
||||
respective parameters.
|
||||
Args:
|
||||
model (torch.nn.Module): Model to assign parameter values.
|
||||
"""
|
||||
for name, param in self.model.named_parameters():
|
||||
if param.requires_grad:
|
||||
assert name in self.shadow
|
||||
self.original[name] = param.data.clone()
|
||||
param.data = self.shadow[name]
|
||||
|
||||
def _restore(self) -> None:
|
||||
"""Restore original parameters to a model. That is, put back
|
||||
the values that were in each parameter at the last call to `assign`.
|
||||
Args:
|
||||
model (torch.nn.Module): Model to assign parameter values.
|
||||
"""
|
||||
for name, param in self.model.named_parameters():
|
||||
if param.requires_grad:
|
||||
assert name in self.shadow
|
||||
param.data = self.original[name]
|
||||
|
||||
def save_model(self, save_path: str) -> None:
|
||||
"""
|
||||
Saves EMA model parameters to a checkpoint file.
|
||||
"""
|
||||
self._assign()
|
||||
state = {'ema_model': self.model.state_dict(),
|
||||
'ema_params': self.shadow,
|
||||
'step_count': self.step_count}
|
||||
torch.save(state, save_path)
|
||||
self._restore()
|
||||
|
||||
def restore_from_checkpoint(self, path: str) -> None:
|
||||
state = torch.load(path)
|
||||
self.shadow = state['ema_params']
|
||||
self.step_count = state['step_count']
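if __name__ == "__main__":
    # Usage sketch (added for illustration): keep an EMA "teacher" alongside a student model
    # and query it for predictions. The toy linear model and random data below are
    # placeholders, not the models used in the actual experiments.
    student = torch.nn.Linear(4, 2)
    teacher = EMA(student, decay=0.99, step_max=10)
    optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
    for _ in range(5):
        inputs = torch.randn(8, 4)
        labels = torch.randint(0, 2, (8,))
        loss = torch.nn.functional.cross_entropy(student(inputs), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        teacher.update()  # fold the new student weights into the moving average
    teacher_logits = teacher.inference(torch.randn(8, 4))  # predictions with EMA weights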
|
|
@ -0,0 +1,22 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Lambda(nn.Module):
|
||||
"""
|
||||
Lambda torch nn module that can be used to modularise nn.functional methods.
|
||||
"""
|
||||
|
||||
def __init__(self, fn: Callable) -> None:
|
||||
super(Lambda, self).__init__()
|
||||
self.lambda_func = fn
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return self.lambda_func(input)
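if __name__ == "__main__":
    # Usage sketch (added for illustration): wrap a functional operation so that it can sit
    # inside an nn.Sequential, e.g. global average pooling before a linear head. The layer
    # sizes are arbitrary placeholders.
    head = nn.Sequential(Lambda(lambda t: t.mean(dim=[2, 3])), nn.Linear(16, 10))
    out = head(torch.randn(2, 16, 8, 8))  # -> shape (2, 10)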
|
|
@ -0,0 +1,84 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from InnerEyeDataQuality.deep_learning.self_supervised.ssl_classifier_module import PretrainedClassifier, SSLClassifier
|
||||
|
||||
|
||||
class CollectCNNEmbeddings:
|
||||
"""
|
||||
This class takes care of registering a forward hook to get the embeddings for a given model.
|
||||
"""
|
||||
|
||||
def __init__(self, use_only_in_train: bool, store_input_tensor: bool) -> None:
|
||||
"""
|
||||
:param use_only_in_train: If set to True, hooks are registered only for model forward passes in training mode.
|
||||
:param store_input_tensor:
|
||||
"""
|
||||
self.inputs: list = list()
|
||||
self.layer: Optional[torch.nn.Module] = None
|
||||
self.use_only_in_train = use_only_in_train
|
||||
self.store_input_tensor = store_input_tensor
|
||||
|
||||
def __call__(self,
|
||||
module: torch.nn.Module,
|
||||
module_in: List[torch.Tensor],
|
||||
module_out: torch.Tensor) -> None:
|
||||
|
||||
# If model is in validation state and only training time collection is allowed, then exit.
|
||||
if self.use_only_in_train and not module.training:
|
||||
return
|
||||
if module.use_hook:
|
||||
_tensor = module_in[0] if self.store_input_tensor else module_out
|
||||
self.inputs.append(_tensor.detach().cpu())
|
||||
|
||||
def return_embeddings(self, return_numpy: bool = True) -> Union[np.ndarray, torch.Tensor, None]:
|
||||
if len(self.inputs) == 0:
|
||||
return None
|
||||
embeddings = torch.cat(self.inputs, dim=0)
|
||||
return embeddings.cpu().numpy() if return_numpy else embeddings
|
||||
|
||||
def reset(self) -> None:
|
||||
self.inputs = list()
|
||||
|
||||
|
||||
def register_embeddings_collector(models: List[torch.nn.Module],
|
||||
use_only_in_train: bool = False) -> List[CollectCNNEmbeddings]:
|
||||
"""
|
||||
Takes a list of models and registers a forward hook on the embedding layer of each model.
:param models: List of torch modules.
:param use_only_in_train: If set to True, embeddings are collected only for model forward passes in training mode.
|
||||
"""
|
||||
assert(isinstance(use_only_in_train, bool))
|
||||
all_model_cnn_embeddings = []
|
||||
for model in models:
|
||||
store_input_tensor = not isinstance(model, SSLClassifier)
|
||||
cnn_embeddings = CollectCNNEmbeddings(use_only_in_train, store_input_tensor)
|
||||
if hasattr(model, "projection"):
|
||||
cnn_embeddings.layer = model.projection # type: ignore
|
||||
elif hasattr(model, "resnet"):
|
||||
cnn_embeddings.layer = model.resnet.fc # type: ignore
|
||||
elif isinstance(model, PretrainedClassifier):
|
||||
cnn_embeddings.layer = model.classifier_head
|
||||
else:
|
||||
cnn_embeddings.layer = model.fc # type: ignore
|
||||
cnn_embeddings.layer.use_hook = True # type: ignore
|
||||
cnn_embeddings.layer.register_forward_hook(cnn_embeddings) # type: ignore
|
||||
all_model_cnn_embeddings.append(cnn_embeddings)
|
||||
return all_model_cnn_embeddings
|
||||
|
||||
|
||||
def get_all_embeddings(embeddings_collectors: List[CollectCNNEmbeddings]) -> List[np.ndarray]:
|
||||
"""
|
||||
Returns all embeddings from a list of embeddings collectors and resets the list
|
||||
"""
|
||||
output = list()
|
||||
for cnn_embeddings in embeddings_collectors:
|
||||
output.append(cnn_embeddings.return_embeddings(return_numpy=True))
|
||||
cnn_embeddings.reset()
|
||||
return output
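if __name__ == "__main__":
    # Usage sketch (added for illustration): collect penultimate-layer embeddings from a
    # torchvision ResNet-18, which exposes a `fc` attribute and therefore hits the fallback
    # branch in register_embeddings_collector. The input size is an arbitrary placeholder.
    import torchvision

    model = torchvision.models.resnet18()
    collectors = register_embeddings_collector([model], use_only_in_train=False)
    model.eval()
    with torch.no_grad():
        model(torch.randn(2, 3, 64, 64))
    embeddings = get_all_embeddings(collectors)[0]  # shape (2, 512): inputs to model.fc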
|
|
@ -0,0 +1,83 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Any, Callable, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torchvision
|
||||
from torchvision.transforms import ToTensor
|
||||
|
||||
from InnerEyeDataQuality.configs.config_node import ConfigNode
|
||||
from InnerEyeDataQuality.deep_learning.transforms import AddGaussianNoise, CenterCrop, ElasticTransform, \
|
||||
ExpandChannels, RandomAffine, RandomColorJitter, RandomErasing, RandomGamma, \
|
||||
RandomHorizontalFlip, RandomResizeCrop, Resize
|
||||
|
||||
|
||||
def _get_dataset_stats(
|
||||
config: ConfigNode) -> Tuple[np.ndarray, np.ndarray]:
|
||||
name = config.dataset.name
|
||||
if name == 'CIFAR10':
|
||||
mean = np.array([0.4914, 0.4822, 0.4465])
|
||||
std = np.array([0.2470, 0.2435, 0.2616])
|
||||
else:
|
||||
raise ValueError()
|
||||
return mean, std
|
||||
|
||||
|
||||
def create_transform(config: ConfigNode, is_train: bool) -> Callable:
|
||||
if config.dataset.name in ["NoisyChestXray", "Kaggle"]:
|
||||
return create_chest_xray_transform(config, is_train)
|
||||
elif config.dataset.name in ["CIFAR10", "CIFAR10H", "CIFAR10IDN", "CIFAR10H_TRAIN_VAL", "CIFAR10SYM"]:
|
||||
return create_cifar_transform(config, is_train)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
|
||||
def create_cifar_transform(config: ConfigNode,
|
||||
is_train: bool) -> Callable:
|
||||
transforms: List[Any] = list()
|
||||
if is_train:
|
||||
if config.augmentation.use_random_affine:
|
||||
transforms.append(RandomAffine(config))
|
||||
if config.augmentation.use_random_crop:
|
||||
transforms.append(RandomResizeCrop(config))
|
||||
if config.augmentation.use_random_horizontal_flip:
|
||||
transforms.append(RandomHorizontalFlip(config))
|
||||
if config.augmentation.use_random_color:
|
||||
transforms.append(RandomColorJitter(config))
|
||||
transforms += [ToTensor()]
|
||||
return torchvision.transforms.Compose(transforms)
|
||||
|
||||
|
||||
def create_chest_xray_transform(config: ConfigNode,
|
||||
is_train: bool) -> Callable:
|
||||
"""
|
||||
Defines the image transformations pipeline for Chest-Xray datasets.
|
||||
"""
|
||||
transforms: List[Any] = []
|
||||
if is_train:
|
||||
if config.augmentation.use_random_affine:
|
||||
transforms.append(RandomAffine(config))
|
||||
if config.augmentation.use_random_crop:
|
||||
transforms.append(RandomResizeCrop(config))
|
||||
else:
|
||||
transforms.append(Resize(config))
|
||||
if config.augmentation.use_random_horizontal_flip:
|
||||
transforms.append(RandomHorizontalFlip(config))
|
||||
if config.augmentation.use_gamma_transform:
|
||||
transforms.append(RandomGamma(config))
|
||||
if config.augmentation.use_random_color:
|
||||
transforms.append(RandomColorJitter(config))
|
||||
if config.augmentation.use_elastic_transform:
|
||||
transforms.append(ElasticTransform(config))
|
||||
transforms += [CenterCrop(config), ToTensor()]
|
||||
if config.augmentation.use_random_erasing:
|
||||
transforms.append(RandomErasing(config))
|
||||
if config.augmentation.add_gaussian_noise:
|
||||
transforms.append(AddGaussianNoise(config))
|
||||
else:
|
||||
transforms += [Resize(config), CenterCrop(config), ToTensor()]
|
||||
transforms.append(ExpandChannels())
|
||||
return torchvision.transforms.Compose(transforms)
|
|
@ -0,0 +1,61 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import yacs.config
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
|
||||
|
||||
def get_number_of_samples_per_epoch(dataloader: torch.utils.data.DataLoader) -> int:
|
||||
"""
|
||||
Returns the expected number of samples for a single epoch
|
||||
"""
|
||||
total_num_samples = len(dataloader.dataset) # type: ignore
|
||||
batch_size = dataloader.batch_size
|
||||
drop_last = dataloader.drop_last
|
||||
num_samples = int(total_num_samples / batch_size) * batch_size if drop_last else total_num_samples # type:ignore
|
||||
return num_samples
|
||||
|
||||
|
||||
def get_train_dataloader(train_dataset: VisionDataset,
|
||||
config: yacs.config.CfgNode,
|
||||
seed: int,
|
||||
**kwargs: Any) -> DataLoader:
|
||||
if config.train.use_balanced_sampler:
|
||||
counts = np.bincount(train_dataset.targets)
|
||||
class_weights = counts.sum() / counts
|
||||
sample_weights = class_weights[train_dataset.targets]
|
||||
sampler = torch.utils.data.WeightedRandomSampler(weights=sample_weights, num_samples=len(train_dataset))
kwargs.pop("shuffle", None)
kwargs.update({"sampler": sampler})
|
||||
return torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config.train.batch_size,
|
||||
num_workers=config.train.dataloader.num_workers,
|
||||
pin_memory=config.train.dataloader.pin_memory,
|
||||
worker_init_fn=WorkerInitFunc(seed),
|
||||
**kwargs)
|
||||
|
||||
|
||||
class WorkerInitFunc:
|
||||
def __init__(self, seed: int) -> None:
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int) -> None:
|
||||
return np.random.seed(self.seed + worker_id)
|
||||
|
||||
|
||||
def get_val_dataloader(val_dataset: VisionDataset, config: yacs.config.CfgNode, seed: int) -> DataLoader:
|
||||
return torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
batch_size=config.validation.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=config.validation.dataloader.num_workers,
|
||||
pin_memory=config.validation.dataloader.pin_memory,
|
||||
worker_init_fn=WorkerInitFunc(seed))
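if __name__ == "__main__":
    # Worked example (added for illustration) of the class-balancing weights computed in
    # get_train_dataloader above: the minority class receives a proportionally larger
    # sampling weight, so WeightedRandomSampler draws it more often.
    targets = np.array([0, 0, 0, 1])
    counts = np.bincount(targets)          # [3, 1]
    class_weights = counts.sum() / counts  # [1.33, 4.0]
    print(class_weights[targets])          # [1.33, 1.33, 1.33, 4.0]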
|
|
@ -0,0 +1,111 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
import torch
|
||||
from sklearn.neighbors import kneighbors_graph
|
||||
|
||||
from InnerEyeDataQuality.algorithms.graph import GraphParameters, build_connectivity_graph, label_diffusion
|
||||
from InnerEyeDataQuality.utils.generic import convert_labels_to_one_hot, find_set_difference_torch
|
||||
|
||||
|
||||
class GraphClassifier:
|
||||
"""
|
||||
Graph based classifier. Builds a graph and runs label diffusion to classify new points.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_samples: int,
|
||||
num_classes: int,
|
||||
labels: np.ndarray,
|
||||
device: torch.device) -> None:
|
||||
self.graph = None
|
||||
self.device = device
|
||||
self.num_samples = num_samples
|
||||
self.num_classes = num_classes
|
||||
self.graph_params = GraphParameters(n_neighbors=12,
|
||||
diffusion_alpha=0.95,
|
||||
cg_solver_max_iter=10,
|
||||
diffusion_batch_size=None,
|
||||
distance_kernel="cosine")
|
||||
|
||||
# Convert to one-hot label distribution
|
||||
self.labels = np.array(labels) if isinstance(labels, list) else labels
|
||||
self.one_hot_labels = convert_labels_to_one_hot(self.labels, n_classes=num_classes)
|
||||
assert np.all(self.one_hot_labels.sum(axis=1) == 1.0)
|
||||
assert self.labels.shape[0] == num_samples
|
||||
|
||||
def build_graph(self, embeddings: np.ndarray) -> None:
|
||||
logging.info("Building a new connectivity graph")
|
||||
assert embeddings.shape[0] == self.num_samples
|
||||
|
||||
# Build a connectivity graph and k-nearest neighbours.
|
||||
n_neighbors = self.graph_params.n_neighbors
|
||||
self.knn = kneighbors_graph(embeddings, n_neighbors, metric=self.graph_params.distance_kernel, n_jobs=-1)
|
||||
self.graph = build_connectivity_graph(normalised=True,
|
||||
embeddings=embeddings,
|
||||
n_neighbors=n_neighbors,
|
||||
distance_kernel=self.graph_params.distance_kernel)
|
||||
laplacian = scipy.sparse.eye(self.num_samples) - self.graph_params.diffusion_alpha * self.graph # type: ignore
|
||||
self.laplacian_inv = scipy.sparse.linalg.inv(laplacian.tocsc()).todense()
|
||||
|
||||
def fit(self, query_batch_ids: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Run label diffusion and identify potential labels for query samples
|
||||
"""
|
||||
diffused_labels = label_diffusion(inv_laplacian=self.laplacian_inv,
|
||||
labels=self.one_hot_labels,
|
||||
query_batch_ids=query_batch_ids,
|
||||
diffusion_normalizing_factor=0.01)
|
||||
assert np.all(diffused_labels.shape == (query_batch_ids.size, self.num_classes))
|
||||
assert not np.isnan(diffused_labels).any()
|
||||
return diffused_labels
|
||||
|
||||
def filter_cases(self,
|
||||
local_ind_keep: torch.Tensor,
|
||||
local_ind_exclude: torch.Tensor,
|
||||
global_ind: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Filter list of cases to drop based on diffused labels. If labels and diffused labels agree, do not
|
||||
exclude the sample.
|
||||
:param local_ind_keep: original list of samples to keep
|
||||
:param local_ind_exclude: original list of samples to drop
|
||||
:param global_ind: list of global indices
|
||||
:return: updated list of indices to keep and drop
|
||||
"""
|
||||
# If input indices are empty return an empty tensor
|
||||
if torch.numel(local_ind_exclude) == 0:
|
||||
return local_ind_keep, local_ind_exclude
|
||||
|
||||
# Check input variable consistency
|
||||
num_samples = torch.numel(global_ind)
|
||||
assert num_samples == torch.numel(local_ind_exclude) + torch.numel(local_ind_keep)
|
||||
global_ind = global_ind.to(local_ind_keep.device)
|
||||
all_local_ind = torch.tensor(range(num_samples), device=local_ind_keep.device, dtype=local_ind_keep.dtype)
|
||||
|
||||
# Run graph diffusion to filter out incorrectly picked indices.
|
||||
global_ind_exclude = global_ind[local_ind_exclude].cpu().numpy()
|
||||
diffused_probs = self.fit(global_ind_exclude)
|
||||
graph_pred = np.argmax(diffused_probs, axis=1)
|
||||
initial_labels = self.labels[global_ind_exclude]
|
||||
|
||||
# Update the local indices for exclude
|
||||
local_ind_exclude_updated = local_ind_exclude[graph_pred != initial_labels]
|
||||
local_ind_keep_updated = find_set_difference_torch(all_local_ind, local_ind_exclude_updated)
|
||||
|
||||
return local_ind_keep_updated, local_ind_exclude_updated
|
||||
|
||||
def compute_mingling_index(self, indices: np.ndarray) -> None:
|
||||
"""
|
||||
Computes mingling index of each graph node based on label distribution in local graphs.
|
||||
"""
|
||||
mingling = np.zeros(indices.shape, dtype=float)
|
||||
for loop_id, _ind in enumerate(indices):
|
||||
disagreement = self.labels[self.knn[_ind].indices] != self.labels[_ind]
|
||||
mingling[loop_id] = np.sum(disagreement) / float(disagreement.size)
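if __name__ == "__main__":
    # Usage sketch (added for illustration), assuming the graph utilities imported above
    # behave as they are used in filter_cases. Random embeddings and labels stand in for
    # the model embeddings used in the actual label cleaning pipeline.
    rng = np.random.RandomState(0)
    num_samples, num_classes = 200, 2
    labels = rng.randint(0, num_classes, num_samples)
    classifier = GraphClassifier(num_samples=num_samples, num_classes=num_classes,
                                 labels=labels, device=torch.device("cpu"))
    classifier.build_graph(embeddings=rng.randn(num_samples, 32).astype(np.float32))
    diffused = classifier.fit(query_batch_ids=np.arange(10))  # (10, num_classes) posteriors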
|
|
@ -0,0 +1,75 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Union
|
||||
|
||||
import yacs
|
||||
import torch
|
||||
from torch import Tensor as T
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def tanh_loss(logits: T) -> T:
|
||||
"""
|
||||
Penalises sparse logits with large values
|
||||
"Bootstrap Your Own Latent A New Approach to Self-Supervised Learning" Supplementary C.1
|
||||
"""
|
||||
alpha = 10
|
||||
tclip = alpha * torch.tanh(logits / alpha)
|
||||
return torch.mean(torch.pow(tclip, 2.0))
|
||||
|
||||
# Computes consistency loss between unlabelled or excluded points
|
||||
def consistency_loss(logits_source: T, logits_target: T) -> Union[T, float]:
|
||||
"""
|
||||
Class probability consistency loss based on the L2 distance between the two class posteriors.
|
||||
"""
|
||||
if (logits_source.numel() > 0) & (logits_source.shape == logits_target.shape):
|
||||
_prob_source = torch.softmax(logits_source, dim=-1)
|
||||
_prob_target = torch.softmax(logits_target, dim=-1)
|
||||
loss = torch.mean(torch.norm(_prob_source - _prob_target.detach(), p=2, dim=-1))
|
||||
return loss
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def early_regularisation_loss(student_logits: torch.Tensor, ema_logits: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Implements early regularisation loss term proposed in:
|
||||
https://arxiv.org/abs/2007.00151
|
||||
"""
|
||||
posteriors_teacher = torch.softmax(ema_logits, dim=-1).detach()
|
||||
posteriors_student = torch.softmax(student_logits, dim=-1)
|
||||
inner_prod = torch.sum(posteriors_teacher * posteriors_student, dim=-1)
|
||||
early_regularisation = torch.mean(torch.log(1 - inner_prod))
|
||||
|
||||
return early_regularisation
|
||||
|
||||
def onehot_encoding(label: torch.Tensor, n_classes: int) -> torch.Tensor:
|
||||
return torch.zeros(label.size(0), n_classes).to(label.device).scatter_(1, label.view(-1, 1), 1)
|
||||
|
||||
class CrossEntropyLoss:
|
||||
"""
|
||||
Cross entropy loss - implements label smoothing
|
||||
"""
|
||||
|
||||
def __init__(self, config: yacs.config.CfgNode):
|
||||
self.n_classes = config.dataset.n_classes
|
||||
self.use_label_smoothing = config.augmentation.use_label_smoothing
|
||||
self.epsilon = config.augmentation.label_smoothing.epsilon
|
||||
|
||||
def __call__(self, predictions: torch.Tensor, targets: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
|
||||
if self.use_label_smoothing:
|
||||
device = predictions.device
|
||||
onehot = onehot_encoding(targets, self.n_classes).type_as(predictions).to(device)
|
||||
targets = onehot * (1 - self.epsilon) + torch.ones_like(onehot).to(device) * self.epsilon / self.n_classes
|
||||
logp = F.log_softmax(predictions, dim=1)
|
||||
loss_per_sample = torch.sum(-logp * targets, dim=1)
|
||||
if reduction == 'none':
|
||||
return loss_per_sample
|
||||
elif reduction == 'mean':
|
||||
return loss_per_sample.mean()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
return F.cross_entropy(predictions, targets, reduction=reduction)
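if __name__ == "__main__":
    # Usage sketch (added for illustration): the config below mirrors only the fields read by
    # CrossEntropyLoss and is an assumption, not the full project configuration schema.
    from yacs.config import CfgNode

    cfg = CfgNode({"dataset": {"n_classes": 2},
                   "augmentation": {"use_label_smoothing": True,
                                    "label_smoothing": {"epsilon": 0.1}}})
    criterion = CrossEntropyLoss(cfg)
    logits = torch.tensor([[2.0, 0.5], [0.2, 1.5]])
    targets = torch.tensor([0, 1])
    # With epsilon = 0.1 the one-hot target [1, 0] becomes [0.95, 0.05] before the
    # cross-entropy is taken, which slightly penalises over-confident predictions.
    print(criterion(logits, targets))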
|
|
@ -0,0 +1,175 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from InnerEyeDataQuality.datasets.cifar10_utils import get_cifar10_label_names
|
||||
from InnerEyeDataQuality.deep_learning.metrics.sample_metrics import SampleMetrics
|
||||
from InnerEyeDataQuality.deep_learning.metrics.plots_tensorboard import (get_scatter_plot, plot_disagreement_per_sample,
|
||||
plot_excluded_cases_coteaching)
|
||||
from InnerEyeDataQuality.deep_learning.transforms import ToNumpy
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torchvision.datasets import CIFAR10
|
||||
|
||||
|
||||
@dataclass()
|
||||
class JointMetrics():
|
||||
"""
|
||||
Stores metrics for co-teaching models.
|
||||
"""
|
||||
num_samples: int
|
||||
num_epochs: int
|
||||
dataset: Optional[Any] = None
|
||||
ambiguous_mislabelled_ids: np.ndarray = None
|
||||
clear_mislabelled_ids: np.ndarray = None
|
||||
true_label_entropy: np.ndarray = None
|
||||
plot_dropped_images: bool = False
|
||||
|
||||
def reset(self) -> None:
|
||||
self.kl_divergence_symmetric = np.full([self.num_samples], np.nan)
|
||||
self.active = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.reset()
|
||||
self.prediction_disagreement = np.zeros([self.num_samples, self.num_epochs], dtype=bool)
self._initialise_dataset_properties()
self.case_drop_histogram = np.zeros([self.num_samples, self.num_epochs + 1], dtype=bool)
|
||||
if not isinstance(self.clear_mislabelled_ids, np.ndarray) or not isinstance(self.ambiguous_mislabelled_ids,
|
||||
np.ndarray):
|
||||
return
|
||||
|
||||
self.true_mislabelled_ids = np.concatenate([self.ambiguous_mislabelled_ids, self.clear_mislabelled_ids], axis=0)
|
||||
self.case_drop_histogram[self.clear_mislabelled_ids, -1] = True
|
||||
self.case_drop_histogram[self.ambiguous_mislabelled_ids, -1] = True
|
||||
|
||||
def _initialise_dataset_properties(self) -> None:
|
||||
if self.dataset is not None:
|
||||
self.label_names = get_cifar10_label_names() if isinstance(self.dataset, CIFAR10) \
|
||||
else self.dataset.get_label_names()
|
||||
self.dataset.transform = ToNumpy() # type: ignore
|
||||
|
||||
def log_results(self, writer: SummaryWriter, epoch: int, sample_metrics: SampleMetrics) -> None:
|
||||
if (not self.active) or (self.ambiguous_mislabelled_ids is None) or (self.clear_mislabelled_ids is None):
|
||||
return
|
||||
|
||||
# KL Divergence between the two posteriors
|
||||
writer.add_scalars(main_tag='symmetric-kl-divergence', tag_scalar_dict={
|
||||
'all': np.nanmean(self.kl_divergence_symmetric),
|
||||
'ambiguous': np.nanmean(self.kl_divergence_symmetric[self.ambiguous_mislabelled_ids]),
|
||||
'clear_noise': np.nanmean(self.kl_divergence_symmetric[self.clear_mislabelled_ids])},
|
||||
global_step=epoch)
|
||||
|
||||
# Disagreement rate between the models
|
||||
writer.add_scalars(main_tag='disagreement_rate', tag_scalar_dict={
|
||||
'all': np.nanmean(self.prediction_disagreement[:, epoch]),
|
||||
'ambiguous': np.nanmean(self.prediction_disagreement[self.ambiguous_mislabelled_ids, epoch]),
|
||||
'clear_noise': np.nanmean(self.prediction_disagreement[self.clear_mislabelled_ids, epoch])},
|
||||
global_step=epoch)
|
||||
|
||||
# Add histogram for the loss values
|
||||
self.log_loss_values(writer, sample_metrics.loss_per_sample[:, epoch], epoch)
|
||||
|
||||
# Add disagreement metrics
|
||||
fig = get_scatter_plot(self.true_label_entropy, self.kl_divergence_symmetric,
|
||||
x_label="Label entropy", y_label="Symmetric-KL", y_lim=[0.0, 2.0])
|
||||
writer.add_figure('Sym-KL vs Label Entropy', figure=fig, global_step=epoch, close=True)
|
||||
|
||||
fig = plot_disagreement_per_sample(self.prediction_disagreement, self.true_label_entropy)
|
||||
writer.add_figure('Disagreement of prediction', figure=fig, global_step=epoch, close=True)
|
||||
|
||||
# Excluded cases diagnostics
|
||||
self.log_dropped_cases_metrics(writer=writer, epoch=epoch)
|
||||
|
||||
# Every 10 epochs, display the dropped cases in the co-teaching algorithm
|
||||
if epoch % 10 == 0 and self.plot_dropped_images:
|
||||
self.log_dropped_images(writer=writer, predictions=sample_metrics.predictions, epoch=epoch)
|
||||
|
||||
# Close all figures
|
||||
plt.close('all')
|
||||
|
||||
def log_dropped_cases_metrics(self, writer: SummaryWriter, epoch: int) -> None:
|
||||
"""
|
||||
Creates all diagnostics for dropped cases analysis.
|
||||
"""
|
||||
entropy_sorted_indices = np.argsort(self.true_label_entropy)
|
||||
drop_cur_epoch_mask = self.case_drop_histogram[:, epoch]
|
||||
drop_cur_epoch_ids = np.where(drop_cur_epoch_mask)[0]
|
||||
is_sample_dropped = np.any(drop_cur_epoch_mask)
|
||||
title = None
|
||||
if is_sample_dropped:
|
||||
n_dropped = float(drop_cur_epoch_ids.size)
|
||||
average_label_entropy_dropped_cases = np.mean(self.true_label_entropy[drop_cur_epoch_mask])
|
||||
n_detected_mislabelled = np.intersect1d(drop_cur_epoch_ids, self.true_mislabelled_ids).size
|
||||
n_clean_dropped = int(n_dropped - n_detected_mislabelled)
|
||||
n_detected_mislabelled_ambiguous = np.intersect1d(drop_cur_epoch_ids, self.ambiguous_mislabelled_ids).size
|
||||
n_detected_mislabelled_clear = np.intersect1d(drop_cur_epoch_ids, self.clear_mislabelled_ids).size
|
||||
perc_detected_mislabelled = n_detected_mislabelled / n_dropped * 100
|
||||
perc_detected_clear_mislabelled = n_detected_mislabelled_clear / n_dropped * 100
|
||||
perc_detected_ambiguous_mislabelled = n_detected_mislabelled_ambiguous / n_dropped * 100
|
||||
title = f"Dropped Cases: Avg label entropy {average_label_entropy_dropped_cases:.3f}\n " \
|
||||
f"Dropped cases: {n_detected_mislabelled} mislabelled ({perc_detected_mislabelled:.1f}%) - " \
|
||||
f"{n_clean_dropped} clean ({(100 - perc_detected_mislabelled):.1f}%)\n" \
|
||||
f"Num ambiguous mislabelled among detected cases: {n_detected_mislabelled_ambiguous}" \
|
||||
f" ({perc_detected_ambiguous_mislabelled:.1f}%)\n" \
|
||||
f"Num clear mislabelled among detected cases: {n_detected_mislabelled_clear}" \
|
||||
f" ({perc_detected_clear_mislabelled:.1f}%)"
|
||||
writer.add_scalars(main_tag='Number of dropped cases', tag_scalar_dict={
|
||||
'clean_cases': n_clean_dropped,
|
||||
'all_mislabelled_cases': n_detected_mislabelled,
|
||||
'mislabelled_clear_cases': n_detected_mislabelled_clear,
|
||||
'mislabelled_ambiguous_cases': n_detected_mislabelled_ambiguous}, global_step=epoch)
|
||||
writer.add_scalar(tag="Percentage of mislabelled among dropped cases",
|
||||
scalar_value=perc_detected_mislabelled, global_step=epoch)
|
||||
fig = plot_excluded_cases_coteaching(case_drop_mask=self.case_drop_histogram,
|
||||
entropy_sorted_indices=entropy_sorted_indices, title=title,
|
||||
num_epochs=self.num_epochs, num_samples=self.num_samples)
|
||||
writer.add_figure('Histogram of excluded cases', figure=fig, global_step=epoch, close=True)
|
||||
|
||||
def log_loss_values(self, writer: SummaryWriter, loss_values: np.ndarray, epoch: int) -> None:
|
||||
"""
|
||||
Logs histogram of loss values of one of the co-teaching models.
|
||||
"""
|
||||
writer.add_histogram('loss/all', loss_values, epoch)
|
||||
writer.add_histogram('loss/ambiguous_noise', loss_values[self.ambiguous_mislabelled_ids], epoch)
|
||||
writer.add_histogram('loss/clear_noise', loss_values[self.clear_mislabelled_ids], epoch)
|
||||
|
||||
def log_dropped_images(self, writer: SummaryWriter, predictions: np.ndarray, epoch: int) -> None:
|
||||
"""
|
||||
Logs images dropped during co-teaching training
|
||||
"""
|
||||
dropped_cases = np.where(self.case_drop_histogram[:, epoch])[0]
|
||||
if dropped_cases.size > 0 and self.dataset is not None:
|
||||
dropped_cases = dropped_cases[np.argsort(self.true_label_entropy[dropped_cases])]
|
||||
fig = self.plot_batch_images_and_labels(predictions, list_indices=dropped_cases[:64])
|
||||
writer.add_figure("Dropped images with lowest entropy", figure=fig, global_step=epoch, close=True)
|
||||
fig = self.plot_batch_images_and_labels(predictions, list_indices=dropped_cases[-64:])
|
||||
writer.add_figure("Dropped images with highest entropy", figure=fig, global_step=epoch, close=True)
|
||||
|
||||
kept_cases = np.where(~self.case_drop_histogram[:, epoch])[0]
|
||||
kept_cases = kept_cases[np.argsort(self.true_label_entropy[kept_cases])]
|
||||
fig = self.plot_batch_images_and_labels(predictions, kept_cases[-64:])
|
||||
writer.add_figure("Kept images with highest entropy", figure=fig, global_step=epoch, close=True)
|
||||
|
||||
def plot_batch_images_and_labels(self, predictions: np.ndarray, list_indices: np.ndarray) -> plt.Figure:
|
||||
"""
|
||||
Plots of batch of images along with their labels and predictions. Noise cases are colored in red, clean cases
|
||||
in green. Images are assumed to be numpy images (use ToNumpy() transform).
|
||||
"""
|
||||
assert self.dataset is not None
|
||||
fig, ax = plt.subplots(8, 8, figsize=(8, 10))
|
||||
ax = ax.ravel()
|
||||
for i, index in enumerate(list_indices):
|
||||
predicted = int(predictions[index])
|
||||
color = "red" if index in self.true_mislabelled_ids else "green"
|
||||
_, img, training_label = self.dataset.__getitem__(index)
|
||||
ax[i].imshow(img)
|
||||
ax[i].set_axis_off()
|
||||
ax[i].set_title(f"Label: {self.label_names[training_label]}\nPred: {self.label_names[predicted]}",
|
||||
color=color, fontsize="x-small")
|
||||
return fig
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Optional, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib as mpl
|
||||
|
||||
mpl.use('Agg') # No display
|
||||
|
||||
|
||||
def get_scatter_plot(x_data: np.ndarray, y_data: np.ndarray, scale: np.ndarray = None,
|
||||
title: str = '', x_label: str = '', y_label: str = '',
|
||||
y_lim: Optional[List[float]] = None) -> plt.Figure:
|
||||
fig, ax = plt.subplots()
|
||||
ax.scatter(x_data, y_data, alpha=0.3, s=scale)
|
||||
ax.set_xlabel(x_label)
|
||||
ax.set_ylabel(y_label)
|
||||
ax.set_title(title)
|
||||
ax.grid()
|
||||
if y_lim:
|
||||
ax.set_ylim(y_lim)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def get_histogram_plot(data: Union[List[np.ndarray], np.ndarray], num_bins: int, title: str = '',
|
||||
x_label: str = '', x_lim: Tuple[float, float] = None) -> plt.Figure:
|
||||
"""
|
||||
Creates a histogram plot for the numpy array(s) specified in `data`.
Returns the generated figure object.
|
||||
"""
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
cm = plt.cm.get_cmap('Pastel1')
|
||||
if isinstance(data, list):
|
||||
for _d_id, _d in enumerate(data):
|
||||
ax.hist(_d, density=True, bins=num_bins, color=cm(_d_id), alpha=0.6)
|
||||
else:
|
||||
ax.hist(data, density=True, bins=num_bins)
|
||||
ax.set_ylabel('Sample density')
|
||||
ax.set_xlabel(x_label)
|
||||
ax.set_title(title)
|
||||
ax.grid()
|
||||
|
||||
if x_lim:
|
||||
ax.set_xlim(x_lim)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def plot_excluded_cases_coteaching(case_drop_mask: np.ndarray,
|
||||
entropy_sorted_indices: np.ndarray,
|
||||
num_epochs: int,
|
||||
num_samples: int,
|
||||
title: Optional[str] = None) -> plt.Figure:
|
||||
"""
|
||||
Plots the excluded cases in co-teaching training - training epochs vs sample_ids
|
||||
Samples are sorted based on their true ambiguity score.
|
||||
"""
|
||||
fig, ax = plt.subplots(figsize=(15, 10))
|
||||
ax.imshow(case_drop_mask[entropy_sorted_indices, :].astype(np.uint8) * 255, cmap="gray",
|
||||
extent=[0, num_epochs, 0, num_samples], vmin=0, vmax=10, aspect='auto')
|
||||
ax.set_xlabel("Number of epochs")
|
||||
ax.set_ylabel("Training sample ids (ordered by true entropy)")
|
||||
if title:
|
||||
ax.set_title(title)
|
||||
return fig
|
||||
|
||||
|
||||
def plot_disagreement_per_sample(prediction_disagreement: np.ndarray, true_label_entropy: np.ndarray) -> plt.Figure:
|
||||
"""
|
||||
Plots predicted class disagreement between two models - training epochs vs sample_ids
|
||||
Samples are sorted based on their true ambiguity score.
|
||||
"""
|
||||
entropy_sorted_indices = np.argsort(true_label_entropy)
|
||||
fig, ax = plt.subplots()
|
||||
num_epochs = prediction_disagreement.shape[1]
|
||||
num_samples = prediction_disagreement.shape[0]
|
||||
ax.imshow(prediction_disagreement[entropy_sorted_indices, :].astype(np.uint8) * 2, cmap="gray",
|
||||
extent=[0, num_epochs, 0, num_samples], vmin=0, vmax=1, aspect='auto')
|
||||
ax.set_xlabel("Number of epochs")
|
||||
ax.set_ylabel("Training sample ids (ordered by true entropy)")
|
||||
return fig
|
|
@ -0,0 +1,170 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from sklearn.metrics import auc, f1_score, precision_recall_curve, roc_auc_score
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@dataclass
|
||||
class SampleMetrics():
|
||||
"""
|
||||
Stores the data required to monitor the training of an individual model and
to run post-training analysis for sample selection.
|
||||
"""
|
||||
name: str
|
||||
num_epochs: int
|
||||
num_samples: int
|
||||
num_classes: int
|
||||
clear_labels: np.ndarray = None
|
||||
ambiguous_mislabelled_ids: np.ndarray = None
|
||||
clear_mislabelled_ids: np.ndarray = None
|
||||
embeddings_size: Optional[int] = None
|
||||
true_label_entropy: Optional[np.ndarray] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.reset()
|
||||
self.loss_per_sample = np.full([self.num_samples, self.num_epochs], np.nan)
|
||||
self.logits_per_sample = np.full([self.num_samples, self.num_classes, self.num_epochs], np.nan)
|
||||
|
||||
if self.clear_mislabelled_ids is not None:
|
||||
full_set = set(range(self.num_samples))
|
||||
if self.ambiguous_mislabelled_ids is not None:
|
||||
self.all_mislabelled_ids = np.concatenate(
|
||||
[self.ambiguous_mislabelled_ids, self.clear_mislabelled_ids], 0)
|
||||
else:
|
||||
self.all_mislabelled_ids = self.clear_mislabelled_ids
|
||||
self.all_clean_label_ids = np.array(list(full_set.difference(self.all_mislabelled_ids)))
|
||||
|
||||
def reset(self) -> None:
|
||||
self.loss_optimised: List[float] = list()
|
||||
self.predictions = np.full([self.num_samples], np.nan)
|
||||
self.labels = np.full([self.num_samples], np.nan)
|
||||
self.probabilities = np.full([self.num_samples, self.num_classes], np.nan)
|
||||
self.correct_predictions_teacher = np.full(self.num_samples, np.nan)
|
||||
self.correct_predictions = np.full([self.num_samples], np.nan)
|
||||
self.embeddings_per_sample = np.full([self.num_samples, self.embeddings_size],
|
||||
np.nan) if self.embeddings_size is not None else None
|
||||
|
||||
def get_average_accuracy(self) -> np.ndarray:
|
||||
return np.nanmean(self.correct_predictions)
|
||||
|
||||
def get_average_loss(self) -> np.ndarray:
|
||||
return np.mean(self.loss_optimised)
|
||||
|
||||
def log_results(self, epoch: int, name: str, writer: Optional[SummaryWriter] = None) -> None:
|
||||
mean_loss = self.get_average_loss()
|
||||
accuracy = self.get_average_accuracy()
|
||||
accuracy_on_clean = np.nanmean(self.predictions == self.clear_labels)
|
||||
logging.info(f'{self.name} \t accuracy: {accuracy:.2f} \t loss: {mean_loss:.2e}')
|
||||
if writer is None:
|
||||
return
|
||||
# Store accuracy and loss metrics
|
||||
writer.add_scalar(tag='loss', scalar_value=mean_loss, global_step=epoch)
|
||||
writer.add_scalar(tag='accuracy/acc_on_sampled_labels', scalar_value=accuracy, global_step=epoch)
|
||||
writer.add_scalar(tag='accuracy/acc_on_clean_labels', scalar_value=accuracy_on_clean, global_step=epoch)
|
||||
|
||||
# Binary classification case
|
||||
if self.num_classes == 2:
|
||||
logits = self.logits_per_sample[:, :, epoch]
|
||||
labels = self.labels
|
||||
predictions = np.argmax(logits, axis=-1)
|
||||
available_indices = np.where(~np.isnan(self.labels))[0]
|
||||
roc_auc = roc_auc_score(labels[available_indices], logits[available_indices, 1].reshape(-1))
|
||||
f1 = f1_score(labels[available_indices], predictions[available_indices])
|
||||
precision, recall, _ = precision_recall_curve(labels[available_indices], logits[available_indices, 1].reshape(-1))
|
||||
pr_auc = auc(recall, precision)
|
||||
writer.add_scalar(tag='roc_auc', scalar_value=roc_auc, global_step=epoch)
|
||||
writer.add_scalar(tag='f1_score', scalar_value=f1, global_step=epoch)
|
||||
writer.add_scalar(tag='pr_auc', scalar_value=pr_auc, global_step=epoch)
|
||||
logging.info(f'{name} \t roc_auc: {roc_auc: .2f} \t pr_auc: {pr_auc: .2f} \t f1_score: {f1: .2f}')
|
||||
if self.clear_mislabelled_ids is not None:
|
||||
get_sub_f1 = lambda ind: f1_score(labels[ind], predictions[ind]) if len(ind) > 0 else 0
|
||||
clean_available = np.intersect1d(self.all_clean_label_ids, available_indices)
|
||||
mislabelled_available = np.intersect1d(self.all_mislabelled_ids, available_indices)
|
||||
scalar_dict = {
|
||||
'clean_cases': get_sub_f1(clean_available),
|
||||
'all_mislabelled_cases': get_sub_f1(mislabelled_available)}
|
||||
writer.add_scalars(main_tag='f1_breakdown', tag_scalar_dict=scalar_dict, global_step=epoch)
|
||||
get_sub_auc = lambda ind: roc_auc_score(labels[ind], predictions[ind]) if len(ind) > 0 else 0
|
||||
scalar_dict = {
|
||||
'clean_cases': get_sub_auc(clean_available),
|
||||
'all_mislabelled_cases': get_sub_auc(mislabelled_available)}
|
||||
writer.add_scalars(main_tag='auc_breakdown', tag_scalar_dict=scalar_dict, global_step=epoch)
|
||||
|
||||
# Add histogram for the loss values
|
||||
self.log_loss_values(writer, self.loss_per_sample[:, epoch], epoch)
|
||||
|
||||
# Breakdown of the accuracy on different sample types
|
||||
if self.clear_mislabelled_ids is not None:
|
||||
get_sub_acc = lambda ind: np.nanmean(self.correct_predictions[ind])
|
||||
scalar_dict = {
|
||||
'clean_cases': get_sub_acc(self.all_clean_label_ids),
|
||||
'all_mislabelled_cases': get_sub_acc(self.all_mislabelled_ids),
|
||||
'mislabelled_clear_cases': get_sub_acc(self.clear_mislabelled_ids)}
|
||||
if self.ambiguous_mislabelled_ids is not None:
|
||||
scalar_dict.update({'mislabelled_ambiguous_cases': get_sub_acc(self.ambiguous_mislabelled_ids)})
|
||||
writer.add_scalars(main_tag='accuracy_breakdown', tag_scalar_dict=scalar_dict, global_step=epoch)
|
||||
|
||||
# Log mean teacher's accuracy
|
||||
if not np.isnan(self.correct_predictions_teacher).any():
|
||||
writer.add_scalar("teacher_accuracy", np.nanmean(self.correct_predictions_teacher), epoch)
|
||||
|
||||
def get_margin(self, epoch: int) -> np.ndarray:
|
||||
"""
|
||||
Get the margin for each sample defined as logits(y) - max_{y != t}[logits(t)]
|
||||
"""
|
||||
margin = np.full(self.num_samples, np.nan)
|
||||
logits = self.logits_per_sample[:, :, epoch]
|
||||
for i in range(self.num_samples):
|
||||
label = int(self.labels[i])
|
||||
assigned_logit = logits[i, label]
|
||||
order = np.argsort(logits[i])
|
||||
order = order[order != label]
|
||||
other_max_logits = logits[i, order[-1]]
|
||||
margin[i] = assigned_logit - other_max_logits
|
||||
return margin
|
||||
|
||||
def log_loss_values(self, writer: SummaryWriter, loss_values: np.ndarray, epoch: int) -> None:
|
||||
"""
|
||||
Logs histogram of loss values of one of the co-teaching models.
|
||||
"""
|
||||
writer.add_histogram('loss/all', loss_values, epoch)
|
||||
|
||||
def append_batch(
|
||||
self,
|
||||
epoch: int,
|
||||
logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
loss: float,
|
||||
indices: list,
|
||||
per_sample_loss: torch.Tensor,
|
||||
embeddings: Optional[np.ndarray] = None,
|
||||
teacher_logits: Optional[torch.Tensor] = None) -> None:
|
||||
"""
|
||||
Appends the stats collected from a batch of samples to the metrics.
|
||||
"""
|
||||
if teacher_logits is not None:
|
||||
self.correct_predictions_teacher[indices] = torch.eq(torch.argmax(teacher_logits, dim=-1),
|
||||
labels).type(torch.float32).cpu()
|
||||
self.correct_predictions[indices] = torch.eq(torch.argmax(logits, dim=-1), labels).type(torch.float32).cpu()
|
||||
self.predictions[indices] = torch.argmax(logits, dim=-1).cpu()
|
||||
self.loss_per_sample[indices, epoch] = per_sample_loss.cpu()
|
||||
self.logits_per_sample[indices, :, epoch] = logits.cpu()
|
||||
self.probabilities[indices, :] = torch.softmax(logits, dim=-1).cpu()
|
||||
self.labels[indices] = labels.cpu()
|
||||
self.loss_optimised.append(loss)
|
||||
|
||||
if embeddings is not None:
|
||||
# We don't know the size of the features in advance.
|
||||
if self.embeddings_size is None:
|
||||
self.embeddings_size = embeddings.shape[1]
|
||||
self.embeddings_per_sample = np.full([self.num_samples, self.embeddings_size], np.nan)
|
||||
assert self.embeddings_per_sample is not None
|
||||
self.embeddings_per_sample[indices, :] = embeddings
|
|
@ -0,0 +1,92 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from InnerEyeDataQuality.deep_learning.metrics.joint_metrics import JointMetrics
|
||||
from InnerEyeDataQuality.deep_learning.metrics.sample_metrics import SampleMetrics
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
|
||||
class MetricTracker(object):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
output_dir: str,
|
||||
num_epochs: int,
|
||||
num_samples_total: int,
|
||||
num_samples_per_epoch: int,
|
||||
num_classes: int,
|
||||
save_tf_events: bool,
|
||||
dataset: Optional[Dataset] = None,
|
||||
name: str = "default_metric",
|
||||
**sample_info_kwargs: Any):
|
||||
"""
|
||||
Class to track model training metrics.
|
||||
If a co-teaching model is trained, joint model metrics such as the disagreement rate and KL divergence are stored.
|
||||
Similarly, it stores loss and logits values on a per sample basis for each epoch for post-training analysis.
|
||||
This stored data can be utilised in data selection simulation.
|
||||
"""
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.output_dir = output_dir
|
||||
self.name = name
|
||||
self.num_classes = num_classes
|
||||
self.num_samples_total = num_samples_total
|
||||
self.num_samples_per_epoch = num_samples_per_epoch
|
||||
clean_targets = dataset.clean_targets if hasattr(dataset, "clean_targets") else None # type: ignore
|
||||
|
||||
self.joint_model_metrics = JointMetrics(num_samples_total, num_epochs, dataset, **sample_info_kwargs)
|
||||
self.sample_metrics = SampleMetrics(name, num_epochs, num_samples_total, num_classes,
|
||||
clear_labels=clean_targets,
|
||||
embeddings_size=None, **sample_info_kwargs)
|
||||
self.writer = SummaryWriter(log_dir=output_dir) if save_tf_events else None
|
||||
|
||||
def reset(self) -> None:
|
||||
self.sample_metrics.reset()
|
||||
self.joint_model_metrics.reset()
|
||||
|
||||
def log_epoch_and_reset(self, epoch: int) -> None:
|
||||
# assert np.count_nonzero(~np.isnan(self.sample_metrics.loss_per_sample[:, epoch])) == self.num_samples_per_epoch
|
||||
self.sample_metrics.log_results(epoch=epoch, name=self.name, writer=self.writer)
|
||||
if self.writer:
|
||||
self.joint_model_metrics.log_results(self.writer, epoch, self.sample_metrics)
|
||||
# Reset epoch metrics
|
||||
self.reset()
|
||||
|
||||
def append_batch_aggregate(self, epoch: int, logits_x: torch.Tensor, logits_y: torch.Tensor,
|
||||
dropped_cases: torch.Tensor, indices: torch.Tensor) -> None:
|
||||
"""
|
||||
Stores the disagreement stats for co-teaching models
|
||||
"""
|
||||
post_x = torch.softmax(logits_x, dim=-1)
|
||||
post_y = torch.softmax(logits_y, dim=-1)
|
||||
sym_kl_per_sample = torch.sum(post_x * torch.log(post_x / post_y) + post_y * torch.log(post_y / post_x), dim=-1)
|
||||
|
||||
pred_x = torch.argmax(logits_x, dim=-1)
|
||||
pred_y = torch.argmax(logits_y, dim=-1)
|
||||
class_pred_disagreement = pred_x != pred_y
|
||||
self.joint_model_metrics.kl_divergence_symmetric[indices] = sym_kl_per_sample.cpu().numpy()
|
||||
self.joint_model_metrics.prediction_disagreement[indices, epoch] = class_pred_disagreement.cpu().numpy()
|
||||
self.joint_model_metrics.case_drop_histogram[dropped_cases.cpu().numpy(), epoch] = True
|
||||
self.joint_model_metrics.active = True
|
||||
|
||||
def save_loss(self) -> None:
|
||||
output_path = os.path.join(self.output_dir, f'{self.name}_training_stats.npz')
|
||||
if hasattr(self.joint_model_metrics, "case_drop_histogram"):
|
||||
np.savez(output_path,
|
||||
loss_per_sample=self.sample_metrics.loss_per_sample,
|
||||
logits_per_sample=self.sample_metrics.logits_per_sample,
|
||||
dropped_cases=self.joint_model_metrics.case_drop_histogram[:, :-1])
|
||||
else:
|
||||
np.savez(output_path,
|
||||
loss_per_sample=self.sample_metrics.loss_per_sample,
|
||||
logits_per_sample=self.sample_metrics.logits_per_sample)
|
|
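A minimal usage sketch for the tracker above (argument values are illustrative placeholders, and the default dataset=None path is assumed to be supported; a real training loop fills in the per-batch calls):

tracker = MetricTracker(output_dir="logs/run0", num_epochs=10, num_samples_total=50000,
                        num_samples_per_epoch=50000, num_classes=10, save_tf_events=False)
for epoch in range(10):
    # ... per batch: call tracker.sample_metrics.append_batch(...) and, for co-teaching,
    # tracker.append_batch_aggregate(...) ...
    tracker.log_epoch_and_reset(epoch)
tracker.save_loss()  # writes <name>_training_stats.npz with per-sample losses and logits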
@ -0,0 +1,69 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from InnerEyeDataQuality.configs.config_node import ConfigNode
|
||||
from InnerEyeDataQuality.deep_learning.collect_embeddings import get_all_embeddings, register_embeddings_collector
|
||||
from InnerEyeDataQuality.deep_learning.utils import get_run_config
|
||||
from InnerEyeDataQuality.deep_learning.trainers.vanilla_trainer import VanillaTrainer
|
||||
from InnerEyeDataQuality.deep_learning.trainers.model_trainer_base import ModelTrainer
|
||||
from InnerEyeDataQuality.deep_learning.trainers.co_teaching_trainer import CoTeachingTrainer
|
||||
from InnerEyeDataQuality.deep_learning.dataloader import get_val_dataloader
|
||||
from InnerEyeDataQuality.utils.custom_types import SelectorTypes as ST
|
||||
|
||||
NUM_MULTIPLE_RUNS = 10
|
||||
|
||||
def inference_model(dataloader: Any, model_trainer: ModelTrainer, use_mc_sampling: bool = False) -> Tuple[List, List]:
|
||||
"""
|
||||
Performs an inference pass with the given model trainer on the provided dataloader.
:param dataloader: Dataloader yielding the samples to run inference on.
:param model_trainer: Model trainer whose trained model(s) are used for inference.
:param use_mc_sampling: If True, Monte-Carlo sampling (dropout at inference time) is used.
:return: Tuple of (embeddings, probabilities) collected from each model.
|
||||
"""
|
||||
# Inference on given dataloader
|
||||
all_model_cnn_embeddings = register_embeddings_collector(model_trainer.models, use_only_in_train=False)
|
||||
trackers = model_trainer.run_inference(dataloader, use_mc_sampling)
|
||||
embs = get_all_embeddings(all_model_cnn_embeddings)
|
||||
probs = [metric_tracker.sample_metrics.probabilities for metric_tracker in trackers]
|
||||
return embs, probs
|
||||
|
||||
|
||||
def inference_ensemble(dataset: Any, config: ConfigNode) -> Tuple[np.ndarray, np.ndarray, np.ndarray, ModelTrainer]:
|
||||
"""
|
||||
Returns:
|
||||
embeddings: 2D numpy array - containing sample embeddings obtained from a CNN.
|
||||
[num_samples, embedding_size]
|
||||
avg_posteriors: 2D numpy array - class posteriors averaged over all inference runs.
[num_samples, num_classes]
all_posteriors: numpy array with the class posteriors from each individual inference run.
trainer: Model trainer object built using the config file
|
||||
"""
|
||||
|
||||
# Reload the model from config
|
||||
model_trainer_class = CoTeachingTrainer if config.train.use_co_teaching else VanillaTrainer
|
||||
config = get_run_config(config, config.train.seed)
|
||||
model_trainer = model_trainer_class(config=config)
|
||||
model_trainer.load_checkpoints(restore_scheduler=False)
|
||||
|
||||
# Prepare output data structures
|
||||
all_embeddings = []
|
||||
all_posteriors = []
|
||||
|
||||
# Run inference on the given dataset
|
||||
use_mc_sampling = config.model.use_dropout and ST(config.selector.type[0]) == ST.BaldSelector
|
||||
multi_inference = config.train.use_self_supervision or config.model.use_dropout
|
||||
num_runs = NUM_MULTIPLE_RUNS if multi_inference else 1
|
||||
for run_ind in range(num_runs):
|
||||
dataloader = get_val_dataloader(dataset, config, seed=config.train.seed + run_ind)
|
||||
embeddings, posteriors = inference_model(dataloader, model_trainer, use_mc_sampling)
|
||||
all_embeddings.append(embeddings)
|
||||
all_posteriors.append(posteriors)
|
||||
|
||||
# Aggregate results and return
|
||||
embeddings = np.mean(all_embeddings, axis=(0, 1))
|
||||
all_posteriors = np.stack(all_posteriors, axis=0)
|
||||
avg_posteriors = np.mean(all_posteriors, axis=(0, 1))
|
||||
|
||||
return embeddings, avg_posteriors, all_posteriors, model_trainer
|
|
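A hypothetical usage sketch (the config and dataset objects come from the repo's own loaders; the names below are placeholders rather than functions defined in this file):

# config = <run config loaded via the repo's config helpers>
# dataset = <dataset object matching the config>
# embeddings, avg_posteriors, all_posteriors, trainer = inference_ensemble(dataset, config)
# predicted_labels = avg_posteriors.argmax(axis=1)  # [num_samples]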
@ -0,0 +1,51 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
class ForgetRateScheduler(object):
|
||||
"""
|
||||
Forget rate scheduler for the co-teaching model
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_epochs: int,
|
||||
forget_rate: float = 0.0,
|
||||
num_gradual: int = 10,
|
||||
num_warmup_epochs: int = 25,
|
||||
start_epoch: int = 0):
|
||||
"""
|
||||
|
||||
:param num_epochs: The total number of training epochs
|
||||
:param forget_rate: The base forget rate
|
||||
:param num_gradual: The number of epochs over which the forget_rate is gradually increased to its base value
:param num_warmup_epochs: The number of initial epochs during which no samples are excluded
:param start_epoch: Allows manually setting the start epoch if training is resumed.
|
||||
"""
|
||||
logging.info(f"No samples will be excluded in co-teaching for the first {num_warmup_epochs} epochs.")
|
||||
if num_gradual <= num_warmup_epochs:
|
||||
logging.warning(f"Num gradual {num_gradual} <= num warm up epochs. This argument will be ignored.")
|
||||
assert 0 <= forget_rate < 1.
|
||||
self.forget_rate_schedule = np.ones(num_epochs) * forget_rate
|
||||
self.forget_rate_schedule[:num_gradual] = np.linspace(0, forget_rate, num_gradual)
|
||||
self.forget_rate_schedule[:num_warmup_epochs] = np.zeros(num_warmup_epochs)
|
||||
self.current_epoch = start_epoch
|
||||
|
||||
def step(self) -> None:
|
||||
"""
|
||||
Increments the current epoch by one.
|
||||
"""
|
||||
self.current_epoch += 1
|
||||
|
||||
@property
|
||||
def get_forget_rate(self) -> float:
|
||||
"""
|
||||
|
||||
:return: The current forget rate
|
||||
"""
|
||||
current_epoch = min(self.current_epoch, len(self.forget_rate_schedule) - 1)
|
||||
return float(self.forget_rate_schedule[current_epoch])
|
|
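A short sketch of driving the schedule above (the numbers are illustrative; real values come from the training config). The rate stays at zero during warm-up, ramps up to the base forget rate, then remains constant:

scheduler = ForgetRateScheduler(num_epochs=40, forget_rate=0.2, num_gradual=30, num_warmup_epochs=5)
for epoch in range(40):
    rate = scheduler.get_forget_rate  # fraction of high-loss samples the co-teaching loss would drop
    # ... run one co-teaching epoch, excluding the top `rate` fraction of samples ...
    scheduler.step()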
@ -0,0 +1,5 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Tuple, Any
|
||||
|
||||
from InnerEyeDataQuality.deep_learning.architectures.lambda_layer import Lambda
|
||||
from InnerEyeDataQuality.deep_learning.self_supervised.ssl_classifier_module import get_encoder_output_dim
|
||||
from InnerEyeDataQuality.deep_learning.self_supervised.utils import create_ssl_encoder
|
||||
from torch import Tensor as T
|
||||
from torch import nn
|
||||
|
||||
|
||||
class _MLP(nn.Module):
|
||||
def __init__(self, input_dim: int, hidden_size: int, output_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.output_dim = output_dim
|
||||
self.input_dim = input_dim
|
||||
self.model = nn.Sequential(
|
||||
nn.Linear(input_dim, hidden_size, bias=False),
|
||||
nn.BatchNorm1d(hidden_size),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(hidden_size, output_dim, bias=True))
|
||||
|
||||
def forward(self, x: T) -> T:
|
||||
x = self.model(x)
|
||||
return x
|
||||
|
||||
class SSLEncoder(nn.Module):
|
||||
def __init__(self, encoder_name: str, dataset_name: str, use_output_pooling: bool = True):
|
||||
super().__init__()
|
||||
self.cnn_model = create_ssl_encoder(encoder_name=encoder_name, dataset_name=dataset_name)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
|
||||
self.use_output_pooling = use_output_pooling
|
||||
|
||||
def forward(self, x: T) -> T:
|
||||
x = self.cnn_model(x)
|
||||
x = x[-1] if isinstance(x, list) else x
|
||||
x = self.avgpool(x).view(x.size(0), -1) if self.use_output_pooling else x
|
||||
return x
|
||||
|
||||
def get_output_feature_dim(self) -> int:
|
||||
return get_encoder_output_dim(self)
|
||||
|
||||
class SiameseArm(nn.Module):
|
||||
def __init__(self, *encoder_kwargs: Any) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.encoder = SSLEncoder(*encoder_kwargs) # Encoder
|
||||
self.projector = _MLP(input_dim=self.encoder.get_output_feature_dim(), hidden_size=2048, output_dim=128)
|
||||
self.predictor = _MLP(input_dim=self.projector.output_dim, hidden_size=128, output_dim=128)
|
||||
self.projector_normalised = nn.Sequential(self.projector,
|
||||
Lambda(lambda x: nn.functional.normalize(x, dim=-1)))
|
||||
|
||||
def forward(self, x: T) -> Tuple[T, T, T]:
|
||||
y = self.encoder(x)
|
||||
z = self.projector(y)
|
||||
h = self.predictor(z)
|
||||
return y, z, h
|
|
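A hypothetical forward-pass sketch (encoder and dataset names must match those supported by create_ssl_encoder; the stated shapes are assumptions based on the projector/predictor sizes defined above):

# arm = SiameseArm("resnet50", "CIFAR10")          # encoder name, dataset name
# y, z, h = arm(torch.rand(4, 3, 32, 32))
# y: encoder features, z: 128-d projector output, h: 128-d predictor output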
@ -0,0 +1,142 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Iterator, List, Tuple, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from InnerEyeDataQuality.deep_learning.self_supervised.byol.byol_models import SiameseArm
|
||||
from InnerEyeDataQuality.deep_learning.self_supervised.byol.byol_moving_average import BYOLMAWeightUpdate
|
||||
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
|
||||
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
|
||||
from torch import Tensor as T
|
||||
from torch.optim import Adam
|
||||
|
||||
BatchType = Tuple[List, T]
|
||||
|
||||
class BYOLInnerEye(pl.LightningModule):
|
||||
"""
|
||||
Implementation of `Bootstrap Your Own Latent (BYOL) <https://arxiv.org/pdf/2006.07733.pdf>`
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_samples: int,
|
||||
learning_rate: float,
|
||||
batch_size: int,
|
||||
encoder_name: str,
|
||||
warmup_epochs: int,
|
||||
dataset_name: Optional[str] = None,
|
||||
weight_decay: float = 1e-6,
|
||||
**kwargs: Any) -> None:
|
||||
"""
|
||||
Args:
|
||||
num_samples: Number of samples present in training dataset / dataloader.
|
||||
learning_rate: Optimizer learning rate.
|
||||
batch_size: Sample batch size used in gradient updates.
|
||||
encoder_name: Type of CNN encoder used to extract image embeddings. The options are:
|
||||
{'resnet18', 'resnet50', 'resnet101'}.
|
||||
warmup_epochs: Number of epochs for scheduler warm up (linear increase from 0 to base_lr).
|
||||
dataset_name: Name of training dataset - If set to "CIFAR10" then the encoder is adjusted to image size.
|
||||
weight_decay: L2-norm weight decay.
|
||||
"""
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
|
||||
self.min_learning_rate = 1e-4
|
||||
self.online_network = SiameseArm(encoder_name, dataset_name)
|
||||
self.target_network = deepcopy(self.online_network)
|
||||
self.weight_callback = BYOLMAWeightUpdate()
|
||||
|
||||
def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
|
||||
# Add callback for user automatically since it's key to BYOL weight update
|
||||
self.weight_callback.on_before_zero_grad(self.trainer, self)
|
||||
|
||||
def forward(self, x: T) -> T: # type: ignore
|
||||
return self.target_network.encoder(x)
|
||||
|
||||
def cosine_loss(self, a: T, b: T) -> T:
|
||||
a = F.normalize(a, dim=-1)
|
||||
b = F.normalize(b, dim=-1)
|
||||
neg_cos_sim = -(a * b).sum(dim=-1).mean()
|
||||
return neg_cos_sim
|
||||
|
||||
def shared_step(self, batch: BatchType, batch_idx: int) -> T:
|
||||
(img_1, img_2), y = batch
|
||||
|
||||
# Image 1 to image 2 loss
|
||||
_, _, h_img1 = self.online_network(img_1)
|
||||
_, _, h_img2 = self.online_network(img_2)
|
||||
with torch.no_grad():
|
||||
_, z_img1, _ = self.target_network(img_1)
|
||||
_, z_img2, _ = self.target_network(img_2)
|
||||
loss = 0.5 * (self.cosine_loss(h_img1, z_img2.detach())
|
||||
+ self.cosine_loss(h_img2, z_img1.detach()))
|
||||
|
||||
return loss
|
||||
|
||||
def training_step(self, batch: BatchType, batch_idx: int) -> T: # type: ignore
|
||||
loss = self.shared_step(batch, batch_idx)
|
||||
self.log_dict({'byol/train_loss': loss, 'byol/tau': self.weight_callback.current_tau})
|
||||
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch: BatchType, batch_idx: int) -> T: # type: ignore
|
||||
loss = self.shared_step(batch, batch_idx)
|
||||
self.log_dict({'byol/validation_loss': loss})
|
||||
|
||||
return loss
|
||||
|
||||
def setup(self, *args: Any, **kwargs: Any) -> None:
|
||||
global_batch_size = self.trainer.world_size * self.hparams.batch_size # type: ignore
|
||||
self.train_iters_per_epoch = self.hparams.num_samples // global_batch_size # type: ignore
|
||||
|
||||
def configure_optimizers(self) -> Any:
|
||||
# TRICK 1 (Use lars + filter weights)
|
||||
# exclude certain parameters
|
||||
parameters = self.exclude_from_wt_decay(self.online_network.named_parameters(),
|
||||
weight_decay=self.hparams.weight_decay) # type: ignore
|
||||
optimizer = LARSWrapper(Adam(parameters, lr=self.hparams.learning_rate)) # type: ignore
|
||||
|
||||
# Trick 2 (after each step)
|
||||
self.hparams.warmup_epochs = self.hparams.warmup_epochs * self.train_iters_per_epoch # type: ignore
|
||||
max_epochs = self.trainer.max_epochs * self.train_iters_per_epoch
|
||||
|
||||
linear_warmup_cosine_decay = LinearWarmupCosineAnnealingLR(
|
||||
optimizer,
|
||||
warmup_epochs=self.hparams.warmup_epochs, # type: ignore
|
||||
max_epochs=max_epochs,
|
||||
warmup_start_lr=0,
|
||||
eta_min=self.min_learning_rate,
|
||||
)
|
||||
|
||||
scheduler = {'scheduler': linear_warmup_cosine_decay, 'interval': 'step', 'frequency': 1}
|
||||
|
||||
return [optimizer], [scheduler]
|
||||
|
||||
def exclude_from_wt_decay(self,
|
||||
named_params: Iterator[Tuple[str, T]],
|
||||
weight_decay: float,
|
||||
skip_list: List[str] = ['bias', 'bn']) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convolution-Linear bias-terms and batch-norm parameters are excluded from l2-norm weight decay regularisation.
|
||||
https://arxiv.org/pdf/2006.07733.pdf Section 3.3 Optimisation and Section F.5.
|
||||
"""
|
||||
params = []
|
||||
excluded_params = []
|
||||
|
||||
for name, param in named_params:
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
elif any(layer_name in name for layer_name in skip_list):
|
||||
excluded_params.append(param)
|
||||
else:
|
||||
params.append(param)
|
||||
|
||||
return [
|
||||
{'params': params, 'weight_decay': weight_decay},
|
||||
{'params': excluded_params, 'weight_decay': 0.}
|
||||
]
|
|
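A worked toy example of the cosine_loss that shared_step averages over both view orderings (toy 2-d vectors, not real embeddings):

import torch
import torch.nn.functional as F

a = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
b = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
a_n, b_n = F.normalize(a, dim=-1), F.normalize(b, dim=-1)
neg_cos_sim = -(a_n * b_n).sum(dim=-1).mean()  # -(1.0 + 0.0) / 2 = -0.5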
@ -0,0 +1,58 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import math
|
||||
import torch
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning import Callback
|
||||
|
||||
|
||||
class BYOLMAWeightUpdate(Callback):
|
||||
"""
|
||||
Weight update rule for the BYOL moving-average encoder (i.e. the target/teacher network). The pl_module is expected to contain three attributes:
|
||||
- ``pl_module.online_network``
|
||||
- ``pl_module.target_network``
|
||||
- ``pl_module.global_step``
|
||||
|
||||
Updates the target_network params using an exponential moving average update rule weighted by tau.
|
||||
The tau parameter is increased from its base value towards 1.0 at every training step, following a cosine schedule.
max_steps corresponds to the total number of SGD updates expected throughout BYOL training, and global_step to the number of updates performed so far.
|
||||
|
||||
Target network is updated at the end of each SGD update on training batch.
|
||||
"""
|
||||
|
||||
def __init__(self, initial_tau: float = 0.99):
|
||||
"""
|
||||
Args:
|
||||
initial_tau: starting tau. Auto-updates with every training step
|
||||
"""
|
||||
super().__init__()
|
||||
self.initial_tau = initial_tau
|
||||
self.current_tau = initial_tau
|
||||
|
||||
def on_before_zero_grad(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: # type: ignore
|
||||
# get networks
|
||||
online_net = pl_module.online_network
|
||||
target_net = pl_module.target_network
|
||||
assert(isinstance(online_net, torch.nn.Module))
|
||||
assert(isinstance(target_net, torch.nn.Module))
|
||||
|
||||
# update weights
|
||||
self.update_weights(online_net, target_net)
|
||||
|
||||
# update tau after
|
||||
self.current_tau = self.update_tau(pl_module, trainer)
|
||||
|
||||
def update_tau(self, pl_module: pl.LightningModule, trainer: pl.Trainer) -> float:
|
||||
max_steps = len(trainer.train_dataloader) * trainer.max_epochs # type: ignore
|
||||
tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * pl_module.global_step / max_steps) + 1) / 2
|
||||
return tau
|
||||
|
||||
def update_weights(self, online_net: torch.nn.Module, target_net: torch.nn.Module) -> None:
|
||||
# apply MA weight update
|
||||
for current_params, ma_params in zip(online_net.parameters(), target_net.parameters()):
|
||||
up_weight, old_weight = current_params.data, ma_params.data
|
||||
ma_params.data = old_weight * self.current_tau + (1 - self.current_tau) * up_weight
|
|
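The cosine tau schedule in update_tau can be checked in isolation (the values below are illustrative):

import math

initial_tau, max_steps = 0.99, 1000
for global_step in (0, 500, 1000):
    tau = 1 - (1 - initial_tau) * (math.cos(math.pi * global_step / max_steps) + 1) / 2
    print(global_step, round(tau, 4))  # 0 -> 0.99, 500 -> 0.995, 1000 -> 1.0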
@ -0,0 +1,5 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
dataset:
|
||||
name: CIFAR10H
|
||||
num_workers: 6
|
||||
train:
|
||||
seed: 1
|
||||
batch_size: 512
|
||||
base_lr: 1e-3
|
||||
output_dir: cifar10h/self_supervised
|
||||
resume_from_last_checkpoint: False
|
||||
self_supervision:
|
||||
type: byol
|
||||
encoder_name: resnet50
|
||||
scheduler:
|
||||
epochs: 2500
|
||||
preprocess:
|
||||
resize: 32
|
|
@ -0,0 +1,16 @@
|
|||
dataset:
|
||||
name: CIFAR10H
|
||||
num_workers: 6
|
||||
train:
|
||||
seed: 1
|
||||
batch_size: 512
|
||||
base_lr: 1e-3
|
||||
output_dir: cifar10h/self_supervised
|
||||
resume_from_last_checkpoint: False
|
||||
self_supervision:
|
||||
type: simclr
|
||||
encoder_name: resnet101
|
||||
scheduler:
|
||||
epochs: 2000
|
||||
preprocess:
|
||||
resize: 32
|
|
@ -0,0 +1,49 @@
|
|||
dataset:
|
||||
name: NIH
|
||||
dataset_dir: datadrive/NIH
|
||||
train:
|
||||
seed: 3
|
||||
batch_size: 2400
|
||||
base_lr: 1e-3
|
||||
output_dir: nih/ssup
|
||||
resume_from_last_checkpoint: False
|
||||
self_supervision:
|
||||
type: byol
|
||||
encoder_name: resnet50
|
||||
use_balanced_binary_loss_for_linear_head: True
|
||||
scheduler:
|
||||
epochs: 1500
|
||||
preprocess:
|
||||
center_crop_size: 224
|
||||
resize: 256
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_random_affine: True
|
||||
use_random_color: True
|
||||
use_random_crop: True
|
||||
use_gamma_transform: True
|
||||
use_random_erasing: True
|
||||
add_gaussian_noise: True
|
||||
use_elastic_transform: True
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_erasing:
|
||||
scale: (0.15, 0.4)
|
||||
ratio: (0.33, 3)
|
||||
random_affine:
|
||||
max_angle: 180
|
||||
max_horizontal_shift: 0.00
|
||||
max_vertical_shift: 0.00
|
||||
max_shear: 40
|
||||
elastic_transform:
|
||||
sigma: 4
|
||||
alpha: 34
|
||||
p_apply: 0.4
|
||||
random_color:
|
||||
brightness: 0.2
|
||||
contrast: 0.2
|
||||
saturation: 0.0
|
||||
random_crop:
|
||||
scale: (0.4, 1.0)
|
||||
gaussian_noise:
|
||||
std: 0.05
|
|
@ -0,0 +1,46 @@
|
|||
dataset:
|
||||
name: NIH
|
||||
dataset_dir: datadrive/NIH
|
||||
train:
|
||||
seed: 2
|
||||
batch_size: 1600
|
||||
base_lr: 1e-3
|
||||
output_dir: nih/nohist
|
||||
resume_from_last_checkpoint: False
|
||||
self_supervision:
|
||||
type: simclr
|
||||
encoder_name: resnet50
|
||||
scheduler:
|
||||
epochs: 1500
|
||||
preprocess:
|
||||
center_crop_size: 224
|
||||
resize: 256
|
||||
augmentation:
|
||||
use_random_horizontal_flip: True
|
||||
use_random_affine: True
|
||||
use_random_color: True
|
||||
use_random_crop: True
|
||||
use_gamma_transform: False
|
||||
use_random_erasing: True
|
||||
add_gaussian_noise: True
|
||||
use_elastic_transform: True
|
||||
random_horizontal_flip:
|
||||
prob: 0.5
|
||||
random_erasing:
|
||||
scale: (0.005, 0.08)
|
||||
ratio: (0.33, 3)
|
||||
random_affine:
|
||||
max_angle: 180
|
||||
max_horizontal_shift: 0.00
|
||||
max_vertical_shift: 0.00
|
||||
max_shear: 40
|
||||
elastic_transform:
|
||||
sigma: 4
|
||||
alpha: 34
|
||||
p_apply: 0.4
|
||||
random_color:
|
||||
brightness: 0.3
|
||||
contrast: 0.3
|
||||
saturation: 0.0
|
||||
random_crop:
|
||||
scale: (0.5, 1.0)
|
|
@ -0,0 +1,85 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from InnerEyeDataQuality.configs.config_node import ConfigNode
|
||||
|
||||
config = ConfigNode()
|
||||
|
||||
config.dataset = ConfigNode()
|
||||
config.dataset.name = 'CIFAR10'
|
||||
config.dataset.dataset_dir = ''
|
||||
config.dataset.num_workers = None
|
||||
|
||||
config.train = ConfigNode()
|
||||
config.train.seed = None
|
||||
config.train.batch_size = None
|
||||
config.train.output_dir = None
|
||||
config.train.resume_from_last_checkpoint = False
|
||||
config.train.base_lr = None
|
||||
config.train.self_supervision = ConfigNode()
|
||||
config.train.self_supervision.type = None
|
||||
config.train.self_supervision.encoder_name = None
|
||||
config.train.self_supervision.use_balanced_binary_loss_for_linear_head = False
|
||||
config.train.checkpoint_period = 200
|
||||
|
||||
config.scheduler = ConfigNode()
|
||||
config.scheduler.epochs = None
|
||||
|
||||
config.augmentation = ConfigNode()
|
||||
config.augmentation.use_random_crop = False
|
||||
config.augmentation.use_random_horizontal_flip = False
|
||||
config.augmentation.use_random_affine = False
|
||||
config.augmentation.use_label_smoothing = False
|
||||
config.augmentation.use_random_color = False
|
||||
config.augmentation.add_gaussian_noise = False
|
||||
config.augmentation.use_gamma_transform = False
|
||||
config.augmentation.use_random_erasing = False
|
||||
config.augmentation.use_elastic_transform = False
|
||||
config.augmentation.random_crop = ConfigNode()
|
||||
config.augmentation.random_crop.scale = (0.9, 1.0)
|
||||
|
||||
config.augmentation.elastic_transform = ConfigNode()
|
||||
config.augmentation.elastic_transform.sigma = 4
|
||||
config.augmentation.elastic_transform.alpha = 35
|
||||
config.augmentation.elastic_transform.p_apply = 0.5
|
||||
|
||||
config.augmentation.gaussian_noise = ConfigNode()
|
||||
config.augmentation.gaussian_noise.std = 0.01
|
||||
config.augmentation.gaussian_noise.p_apply = 0.5
|
||||
|
||||
config.augmentation.random_horizontal_flip = ConfigNode()
|
||||
config.augmentation.random_horizontal_flip.prob = 0.5
|
||||
|
||||
config.augmentation.random_affine = ConfigNode()
|
||||
config.augmentation.random_affine.max_angle = 0
|
||||
config.augmentation.random_affine.max_horizontal_shift = 0.0
|
||||
config.augmentation.random_affine.max_vertical_shift = 0.0
|
||||
config.augmentation.random_affine.max_shear = 5
|
||||
|
||||
config.augmentation.random_color = ConfigNode()
|
||||
config.augmentation.random_color.brightness = 0.0
|
||||
config.augmentation.random_color.contrast = 0.1
|
||||
config.augmentation.random_color.saturation = 0.1
|
||||
|
||||
config.augmentation.gamma = ConfigNode()
|
||||
config.augmentation.gamma.scale = (0.5, 1.5)
|
||||
|
||||
config.augmentation.label_smoothing = ConfigNode()
|
||||
config.augmentation.label_smoothing.epsilon = 0.1
|
||||
|
||||
config.augmentation.random_erasing = ConfigNode()
|
||||
config.augmentation.random_erasing.scale = (0.01, 0.1)
|
||||
config.augmentation.random_erasing.ratio = (0.3, 3.3)
|
||||
|
||||
config.preprocess = ConfigNode()
|
||||
config.preprocess.use_resize = False
|
||||
config.preprocess.use_center_crop = False
|
||||
config.preprocess.center_crop_size = 224
|
||||
config.preprocess.histogram_normalization = ConfigNode()
|
||||
config.preprocess.histogram_normalization.disk_size = 30
|
||||
config.preprocess.resize = 32
|
||||
|
||||
def get_default_model_config() -> ConfigNode:
|
||||
return config.clone()
|
|
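A sketch of how the defaults above would typically be overridden in code (assuming ConfigNode behaves like a yacs-style CfgNode; the override values are illustrative, the real ones come from the YAML config files):

cfg = get_default_model_config()
cfg.dataset.name = "CIFAR10H"
cfg.train.batch_size = 512
cfg.train.self_supervision.type = "byol"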
@ -0,0 +1,156 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import LightningDataModule
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from InnerEyeDataQuality.configs.config_node import ConfigNode
|
||||
from InnerEyeDataQuality.datasets.kaggle_cxr import KAGGLE_TOTAL_SIZE, KaggleCXR
|
||||
from InnerEyeDataQuality.datasets.nih_cxr import NIHCXR, NIH_TOTAL_SIZE
|
||||
from InnerEyeDataQuality.deep_learning.create_dataset_transforms import create_chest_xray_transform
|
||||
from InnerEyeDataQuality.deep_learning.dataloader import WorkerInitFunc
|
||||
from InnerEyeDataQuality.deep_learning.transforms import DualViewTransformWrapper
|
||||
import numpy as np
|
||||
|
||||
class KaggleDataModule(LightningDataModule):
|
||||
|
||||
def __init__(self, config: ConfigNode,
|
||||
num_devices: int,
|
||||
num_workers: int,
|
||||
*args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
This is the data module to load and prepare the Kaggle Pneumonia detection challenge dataset.
|
||||
:param config:
|
||||
:param num_devices: The number of GPUs to use. The total batch size specified in the config will be divided
|
||||
by the number of GPUs.
|
||||
:param num_workers: The number of dataloader workers.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.config = config
|
||||
self.num_samples = KAGGLE_TOTAL_SIZE
|
||||
self.batch_size = config.train.batch_size // num_devices
|
||||
self.num_workers = num_workers
|
||||
self.train_transforms = DualViewTransformWrapper(create_chest_xray_transform(self.config, is_train=True))
|
||||
self.val_transforms = DualViewTransformWrapper(create_chest_xray_transform(self.config, is_train=False))
|
||||
self.train_dataset = KaggleCXR(self.config.dataset.dataset_dir, use_training_split=True,
|
||||
transform=self.train_transforms, return_index=False)
|
||||
self.class_weights: Optional[torch.Tensor] = None
|
||||
if config.train.self_supervision.use_balanced_binary_loss_for_linear_head:
|
||||
# Weight = inverse class proportion.
|
||||
class_weights = len(self.train_dataset.targets) / np.bincount(self.train_dataset.targets)
|
||||
# Normalized class weights
|
||||
class_weights /= class_weights.sum()
|
||||
self.class_weights = torch.tensor(class_weights)
|
||||
|
||||
@property
|
||||
def num_classes(self) -> int:
|
||||
return 2
|
||||
|
||||
def train_dataloader(self) -> DataLoader: # type: ignore
|
||||
"""
|
||||
Returns Kaggle training set (80% of total dataset)
|
||||
"""
|
||||
return torch.utils.data.DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=False,
|
||||
worker_init_fn=WorkerInitFunc(self.config.train.seed),
|
||||
drop_last=True)
|
||||
|
||||
def val_dataloader(self) -> DataLoader: # type: ignore
|
||||
"""
|
||||
Returns Kaggle validation set (20% of total dataset)
|
||||
"""
|
||||
val_dataset = KaggleCXR(self.config.dataset.dataset_dir, use_training_split=False,
|
||||
transform=self.val_transforms, return_index=False)
|
||||
loader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=False,
|
||||
worker_init_fn=WorkerInitFunc(self.config.train.seed),
|
||||
drop_last=True)
|
||||
return loader
|
||||
|
||||
def test_dataloader(self) -> DataLoader: # type: ignore
|
||||
"""
|
||||
No Kaggle test split implemented
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def default_transforms(self) -> Callable:
|
||||
transform = create_chest_xray_transform(self.config, is_train=False)
|
||||
return transform
|
||||
|
||||
|
||||
class NIHDataModule(LightningDataModule):
|
||||
|
||||
def __init__(self, config: ConfigNode, num_devices: int, num_workers: int, *args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
This is the data module to load and prepare the NIH dataset (112k scans).
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.config = config
|
||||
self.batch_size = config.train.batch_size // num_devices
|
||||
self.num_workers = num_workers
|
||||
self.train_transforms = DualViewTransformWrapper(create_chest_xray_transform(self.config, is_train=True))
|
||||
self.val_transforms = DualViewTransformWrapper(create_chest_xray_transform(self.config, is_train=False))
|
||||
self.num_samples = NIH_TOTAL_SIZE
|
||||
self.train_dataset = NIHCXR(self.config.dataset.dataset_dir, use_training_split=True,
|
||||
transform=self.train_transforms, return_index=False)
|
||||
self.class_weights: Optional[torch.Tensor] = None
|
||||
if config.train.self_supervision.use_balanced_binary_loss_for_linear_head:
|
||||
# Weight = inverse class proportion.
|
||||
class_weights = len(self.train_dataset.targets) / np.bincount(self.train_dataset.targets)
|
||||
# Normalized class weights
|
||||
class_weights /= class_weights.sum()
|
||||
self.class_weights = torch.tensor(class_weights)
|
||||
|
||||
@property
|
||||
def num_classes(self) -> int:
|
||||
return 2
|
||||
|
||||
def train_dataloader(self) -> DataLoader: # type: ignore
|
||||
"""
|
||||
Returns NIH training set (80% of total dataset)
|
||||
"""
|
||||
return torch.utils.data.DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=False,
|
||||
worker_init_fn=WorkerInitFunc(self.config.train.seed),
|
||||
drop_last=True)
|
||||
|
||||
def val_dataloader(self) -> DataLoader: # type: ignore
|
||||
"""
|
||||
Returns NIH validation set (20% of total dataset)
|
||||
"""
|
||||
val_dataset = NIHCXR(self.config.dataset.dataset_dir,
|
||||
use_training_split=False, transform=self.val_transforms, return_index=False)
|
||||
loader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=False,
|
||||
worker_init_fn=WorkerInitFunc(self.config.train.seed),
|
||||
drop_last=True)
|
||||
return loader
|
||||
|
||||
def test_dataloader(self) -> DataLoader: # type: ignore
|
||||
"""
|
||||
No NIH test split implemented
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def default_transforms(self) -> Callable:
|
||||
transform = create_chest_xray_transform(self.config, is_train=False)
|
||||
return transform
|
|
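A worked example of the inverse-class-proportion weighting used by both data modules above (illustrative label counts):

import numpy as np

targets = np.array([0] * 80 + [1] * 20)               # 80 negatives, 20 positives
class_weights = len(targets) / np.bincount(targets)   # [1.25, 5.0] = inverse class proportion
class_weights /= class_weights.sum()                  # [0.2, 0.8] after normalisation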
@ -0,0 +1,82 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
from InnerEyeDataQuality.datasets.cifar10h import TOTAL_CIFAR10H_DATASET_SIZE
|
||||
from pl_bolts.datamodules import CIFAR10DataModule
|
||||
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from torchvision import transforms as transform_lib
|
||||
|
||||
|
||||
class CIFAR10HDataModule(CIFAR10DataModule):
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_samples = TOTAL_CIFAR10H_DATASET_SIZE
|
||||
self.class_weights = None
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
"""
|
||||
CIFAR10H training set: the full CIFAR10 test split, on which the CIFAR10H labels are defined.
|
||||
"""
|
||||
transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms
|
||||
dataset = self.DATASET(self.data_dir, train=False, download=True, transform=transforms, **self.extra_args)
|
||||
loader = DataLoader(dataset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=self.num_workers,
|
||||
drop_last=True,
|
||||
pin_memory=True)
|
||||
assert len(dataset) == TOTAL_CIFAR10H_DATASET_SIZE
|
||||
|
||||
return loader
|
||||
|
||||
def val_dataloader(self) -> DataLoader:
|
||||
"""
|
||||
CIFAR10 val set uses a subset of the training set for validation
|
||||
"""
|
||||
transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms
|
||||
|
||||
dataset = self.DATASET(self.data_dir, train=True, download=False, transform=transforms, **self.extra_args)
|
||||
num_samples = len(dataset)
|
||||
_, dataset_val = random_split(dataset,
|
||||
[num_samples - self.val_split, self.val_split],
|
||||
generator=torch.Generator().manual_seed(self.seed))
|
||||
assert len(dataset_val) == self.val_split
|
||||
loader = DataLoader(dataset_val,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=True,
|
||||
drop_last=True)
|
||||
return loader
|
||||
|
||||
def test_dataloader(self) -> DataLoader:
|
||||
"""
|
||||
CIFAR10H test set: the subset of the CIFAR10 train split that is not used for validation.
|
||||
"""
|
||||
transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms
|
||||
|
||||
dataset = self.DATASET(self.data_dir, train=True, download=False, transform=transforms, **self.extra_args)
|
||||
num_samples = len(dataset)
|
||||
dataset_test, _ = random_split(dataset,
|
||||
[num_samples - self.val_split, self.val_split],
|
||||
generator=torch.Generator().manual_seed(self.seed))
|
||||
loader = DataLoader(dataset_test,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.num_workers,
|
||||
drop_last=True,
|
||||
pin_memory=True)
|
||||
return loader
|
||||
|
||||
def default_transforms(self) -> List[object]:
|
||||
cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
|
||||
return cf10_transforms
|
|
@ -0,0 +1,37 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import multiprocessing
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from InnerEyeDataQuality.configs.config_node import ConfigNode
|
||||
from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform
|
||||
|
||||
from .chestxray_datamodule import KaggleDataModule, NIHDataModule
|
||||
from .cifar10h_datamodule import CIFAR10HDataModule
|
||||
|
||||
|
||||
num_gpus = torch.cuda.device_count()
|
||||
num_devices = num_gpus if num_gpus > 0 else 1
|
||||
|
||||
def create_ssl_data_modules(config: ConfigNode) -> pl.LightningDataModule:
|
||||
"""
|
||||
Returns a PyTorch Lightning data module for the dataset specified in the config.
|
||||
"""
|
||||
num_workers = config.dataset.num_workers if config.dataset.num_workers else multiprocessing.cpu_count()
|
||||
|
||||
if config.dataset.name == "Kaggle":
|
||||
dm = KaggleDataModule(config, num_devices=num_devices, num_workers=num_workers) # type: ignore
|
||||
elif config.dataset.name == "NIH":
|
||||
dm = NIHDataModule(config, num_devices=num_devices, num_workers=num_workers) # type: ignore
|
||||
elif config.dataset.name == "CIFAR10H":
|
||||
dm = CIFAR10HDataModule(num_workers=num_workers,
|
||||
batch_size=config.train.batch_size // num_devices,
|
||||
seed=1234)
|
||||
dm.train_transforms = SimCLRTrainDataTransform(32)
|
||||
dm.val_transforms = SimCLREvalDataTransform(32)
|
||||
else:
|
||||
raise NotImplementedError(f"No pytorch data module implemented for dataset type: {config.dataset.name}")
|
||||
return dm
|
|
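A hypothetical usage sketch (the config is loaded via the repo's YAML loader; the file name is a placeholder):

# config = load_ssl_model_config("cifar10h_byol.yaml")   # placeholder path
# dm = create_ssl_data_modules(config)
# train_loader = dm.train_dataloader()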
@ -0,0 +1,22 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import torch
|
||||
from torchvision.models import densenet121
|
||||
|
||||
|
||||
class DenseNet121Encoder(torch.nn.Module):
|
||||
"""
|
||||
This module creates a Densenet121 encoder i.e. Densenet121 model without
|
||||
its classification head.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.densenet121 = densenet121()
|
||||
self.cnn_model = self.densenet121.features
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.cnn_model(x)
|
|
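A small shape check for the encoder above (the [N, 1024, 7, 7] output for 224x224 inputs is an assumption based on the standard torchvision DenseNet121 layout):

import torch

encoder = DenseNet121Encoder()
features = encoder(torch.rand(2, 3, 224, 224))  # expected shape: [2, 1024, 7, 7]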
@ -0,0 +1,118 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
from default_paths import EXPERIMENT_DIR
|
||||
from InnerEyeDataQuality.configs.config_node import ConfigNode
|
||||
from InnerEyeDataQuality.deep_learning.self_supervised.byol.byol_module import BYOLInnerEye
|
||||
from InnerEyeDataQuality.deep_learning.self_supervised.datamodules.utils import create_ssl_data_modules
|
||||
from InnerEyeDataQuality.deep_learning.self_supervised.simclr_module import SimCLRInnerEye
|
||||
from InnerEyeDataQuality.deep_learning.self_supervised.ssl_classifier_module import (SSLOnlineEvaluatorInnerEye,
|
||||
get_encoder_output_dim)
|
||||
from InnerEyeDataQuality.deep_learning.utils import load_ssl_model_config
|
||||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
||||
|
||||
num_gpus = torch.cuda.device_count()
|
||||
num_devices = num_gpus if num_gpus > 0 else 1
|
||||
trained_kwargs: Dict[str, Union[str, int, float]] = {"precision": 16} if num_gpus > 0 else {}
|
||||
if num_gpus > 1: # If multi-gpu training update the parameters for DDP
|
||||
trained_kwargs.update({"distributed_backend": "ddp", "sync_batchnorm": True})
|
||||
|
||||
def get_last_checkpoint_path(default_root_dir: str, model_version: str) -> str:
|
||||
return str(Path(default_root_dir) / model_version / "checkpoints" / "last.ckpt")
|
||||
|
||||
|
||||
def cli_main(config: ConfigNode, debug: bool = False) -> None:
|
||||
"""
|
||||
Runs self-supervised representation learning on imaging data.
Currently only ``BYOL`` and ``SimCLR`` are supported.
|
||||
|
||||
:param config: A ssl_model config specifying training configurations (and augmentations).
|
||||
Beware: the augmentation parameters are currently ignored when using CIFAR10, as the
default augmentations from PyTorch Lightning are always used.
|
||||
:param debug: If set to True, only runs training and validation on 1% of the data.
|
||||
"""
|
||||
|
||||
# set seed
|
||||
seed_everything(config.train.seed)
|
||||
|
||||
# self-supervision type
|
||||
ssl_type = config.train.self_supervision.type
|
||||
default_root_dir = EXPERIMENT_DIR / config.train.output_dir
|
||||
model_version = f'{ssl_type}_seed_{config.train.seed}'
|
||||
checkpoint_dir = str(default_root_dir / model_version / "checkpoints")
|
||||
|
||||
# Model checkpointing callback
|
||||
checkpoint_callback = ModelCheckpoint(period=config.train.checkpoint_period, save_top_k=-1, dirpath=checkpoint_dir)
|
||||
checkpoint_callback_last = ModelCheckpoint(save_last=True, dirpath=checkpoint_dir)
|
||||
|
||||
lr_logger = LearningRateMonitor()
|
||||
tb_logger = pl.loggers.TensorBoardLogger(save_dir=str(default_root_dir),
|
||||
version='logs',
|
||||
name=model_version)
|
||||
# Create SimCLR data modules and model
|
||||
dm = create_ssl_data_modules(config)
|
||||
if ssl_type == "simclr":
|
||||
model = SimCLRInnerEye(num_samples=dm.num_samples, # type: ignore
|
||||
batch_size=dm.batch_size, # type: ignore
|
||||
lr=config.train.base_lr,
|
||||
dataset_name=config.dataset.name,
|
||||
encoder_name=config.train.self_supervision.encoder_name)
|
||||
# Create BYOL model
|
||||
else:
|
||||
model = BYOLInnerEye(num_samples=dm.num_samples, # type: ignore
|
||||
learning_rate=config.train.base_lr,
|
||||
dataset_name=config.dataset.name,
|
||||
encoder_name=config.train.self_supervision.encoder_name,
|
||||
batch_size=dm.batch_size, # type: ignore
|
||||
warmup_epochs=10)
|
||||
model.hparams.update({'ssl_type': ssl_type})
|
||||
|
||||
# Online fine-tuning using an MLP
|
||||
online_eval = SSLOnlineEvaluatorInnerEye(class_weights=dm.class_weights, # type: ignore
|
||||
z_dim=get_encoder_output_dim(model, dm),
|
||||
num_classes=dm.num_classes, # type: ignore
|
||||
dataset=config.dataset.name,
|
||||
drop_p=0.2) # type: ignore
|
||||
|
||||
# Load latest checkpoint
|
||||
resume_from_last_checkpoint = get_last_checkpoint_path(default_root_dir, model_version) if \
|
||||
config.train.resume_from_last_checkpoint else None
|
||||
if debug:
|
||||
overfit_batches = num_devices / (min(len(dm.val_dataloader()), len(dm.train_dataloader())) * 2.0)
|
||||
trained_kwargs.update({"overfit_batches": overfit_batches})
|
||||
|
||||
# Create trainer and run training
|
||||
trainer = pl.Trainer(gpus=num_gpus,
|
||||
logger=tb_logger,
|
||||
default_root_dir=str(default_root_dir),
|
||||
benchmark=True,
|
||||
max_epochs=config.scheduler.epochs,
|
||||
callbacks=[lr_logger, online_eval, checkpoint_callback, checkpoint_callback_last],
|
||||
resume_from_checkpoint=resume_from_last_checkpoint,
|
||||
**trained_kwargs)
|
||||
trainer.fit(model, dm) # type: ignore
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
parser = argparse.ArgumentParser(description='Train a self-supervised model')
|
||||
parser.add_argument('--config', dest='config', type=str, required=True,
|
||||
help='Path to config file characterising trained CNN model/s')
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
config_path = args.config
|
||||
config = load_ssl_model_config(config_path)
|
||||
# Launch the script
|
||||
cli_main(config)
|
||||
|
|
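Example invocation of this training script (both paths below are illustrative placeholders):

# python <path/to/this/script>.py --config <path/to/ssl_config.yaml>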
@ -0,0 +1,38 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch_lightning.metrics import Metric
|
||||
from pytorch_lightning.metrics.functional import auroc
|
||||
|
||||
|
||||
class AreaUnderRocCurve(Metric):
|
||||
"""
|
||||
Computes the area under the receiver operating curve (ROC).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(dist_sync_on_step=False)
|
||||
self.add_state("preds", default=[], dist_reduce_fx=None)
|
||||
self.add_state("targets", default=[], dist_reduce_fx=None)
|
||||
self.name = "auc"
|
||||
|
||||
def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None: # type: ignore
|
||||
assert preds.dim() == 2 and targets.dim() == 1 and preds.shape[1] == 2 and preds.shape[0] == targets.shape[0], \
    f"Expected 2-dim preds, 1-dim targets, but got: preds = {preds.shape}, targets = {targets.shape}"
|
||||
self.preds.append(preds) # type: ignore
|
||||
self.targets.append(targets) # type: ignore
|
||||
|
||||
def compute(self) -> torch.Tensor:
|
||||
"""
|
||||
Computes a metric from the stored predictions and targets.
|
||||
"""
|
||||
preds = torch.cat(self.preds) # type: ignore
|
||||
targets = torch.cat(self.targets) # type: ignore
|
||||
if torch.unique(targets).numel() == 1:
|
||||
return torch.tensor(np.nan)
|
||||
return auroc(preds[:, 1], targets)
|
|
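A minimal sketch exercising the metric above with toy two-class posteriors:

import torch

metric = AreaUnderRocCurve()
metric.update(torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]]), torch.tensor([0, 1, 0]))
auc = metric.compute()  # 1.0: the single positive scores above both negatives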
@ -0,0 +1,43 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from InnerEyeDataQuality.deep_learning.self_supervised.byol.byol_models import SSLEncoder
|
||||
from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR
|
||||
|
||||
|
||||
class _Projection(nn.Module):
|
||||
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.output_dim = output_dim
|
||||
self.input_dim = input_dim
|
||||
self.hidden_dim = hidden_dim
|
||||
|
||||
self.model = nn.Sequential(
|
||||
nn.Linear(self.input_dim, self.hidden_dim, bias=True),
|
||||
nn.BatchNorm1d(self.hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(self.hidden_dim, self.output_dim, bias=False))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.model(x)
|
||||
return F.normalize(x, dim=1)
|
||||
|
||||
|
||||
class SimCLRInnerEye(SimCLR):
|
||||
def __init__(self, encoder_name: str, dataset_name: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Args:
|
||||
encoder_name [str]: Image encoder name (predefined models)
|
||||
dataset_name [str]: Image dataset name (e.g. cifar10, kaggle, etc.)
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.save_hyperparameters()
|
||||
self.encoder = SSLEncoder(encoder_name, dataset_name, use_output_pooling=True)
|
||||
self.projection = _Projection(input_dim=self.encoder.get_output_feature_dim(), hidden_dim=2048, output_dim=128)
|
|
@ -0,0 +1,194 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
|
||||
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
|
||||
from torch import Tensor as T
|
||||
from torch.nn import ModuleList, functional as F
|
||||
|
||||
from InnerEyeDataQuality.deep_learning.self_supervised.metrics import AreaUnderRocCurve
|
||||
from pytorch_lightning.metrics import Accuracy
|
||||
|
||||
|
||||
BatchType = Tuple[List, T]
|
||||
|
||||
class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
|
||||
def __init__(self, class_weights: Optional[torch.Tensor] = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
Creates a hook to evaluate a linear model on top of an SSL embedding.
|
||||
|
||||
:param class_weights: The class weights to use when computing the cross entropy loss. If set to None,
|
||||
no weighting will be done.
|
||||
"""
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self.training_step = int(0)
|
||||
self.weight_decay = 1e-4
|
||||
self.learning_rate = 1e-4
|
||||
|
||||
self.train_metrics = ModuleList([AreaUnderRocCurve(), Accuracy()]) if self.num_classes == 2 else ModuleList([Accuracy()])
|
||||
self.val_metrics = ModuleList([AreaUnderRocCurve(), Accuracy()]) if self.num_classes == 2 else ModuleList([Accuracy()])
|
||||
self.class_weights = class_weights
|
||||
|
||||
def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
|
||||
# Move metrics and class weights to module device
|
||||
for metric in [*self.train_metrics, *self.val_metrics]:
|
||||
metric.to(device=pl_module.device) # type: ignore
|
||||
if self.class_weights is not None:
|
||||
self.class_weights = self.class_weights.float().to(device=pl_module.device)
|
||||
|
||||
pl_module.non_linear_evaluator = SSLEvaluator(n_input=self.z_dim,
|
||||
n_classes=self.num_classes,
|
||||
p=self.drop_p,
|
||||
n_hidden=self.hidden_dim).to(pl_module.device)
|
||||
        assert isinstance(pl_module.non_linear_evaluator, torch.nn.Module)
        self.optimizer = torch.optim.Adam(pl_module.non_linear_evaluator.parameters(),
                                          lr=self.learning_rate,
                                          weight_decay=self.weight_decay)

    # Use only one of the transformed images
    @staticmethod
    def to_device(batch: BatchType, device: Union[str, torch.device]) -> Tuple[T, T]:
        (x1, x2), y = batch
        x1 = x1.to(device)
        y = y.to(device)
        return x1, y

    def shared_step(self, batch: BatchType, pl_module: pl.LightningModule, is_training: bool) -> T:
        x, y = self.to_device(batch, pl_module.device)
        with torch.no_grad():
            representations = self.get_representations(pl_module, x)
        representations = representations.detach()
        assert isinstance(pl_module.non_linear_evaluator, torch.nn.Module)

        # Run the linear-head with SSL embeddings.
        mlp_preds = pl_module.non_linear_evaluator(representations)
        mlp_loss = F.cross_entropy(mlp_preds, y, weight=self.class_weights)

        with torch.no_grad():
            posteriors = F.softmax(mlp_preds, dim=-1)
            for metric in (self.train_metrics if is_training else self.val_metrics):
                metric(posteriors, y)

        return mlp_loss

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):  # type: ignore
        loss = self.shared_step(batch, pl_module, is_training=False)
        # Log classification metrics
        pl_module.log('ssl/online_val_loss', loss, on_step=False, on_epoch=True, sync_dist=False)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None:  # type: ignore
        logger = trainer.logger.experiment  # type: ignore
        loss = self.shared_step(batch, pl_module, is_training=True)

        # update finetune weights
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.training_step += 1

        # log metrics
        logger.add_scalar('ssl/online_train_loss', loss, global_step=self.training_step)


class SSLClassifier(torch.nn.Module):
    """
    SSL Image classifier that combines pre-trained SSL encoder with a trainable linear-head.
    """

    def __init__(self, num_classes: int, encoder: torch.nn.Module, projection: torch.nn.Module):
        super().__init__()
        self.encoder = encoder
        self.projection = projection
        self.encoder.eval(), self.projection.eval()
        self.classifier_head = SSLEvaluator(n_input=get_encoder_output_dim(self.encoder),
                                            n_hidden=None,
                                            n_classes=num_classes,
                                            p=0.20)

    def train(self, mode: bool = True) -> Any:
        self.training = mode
        self.encoder.train(False)
        self.projection.train(False)
        self.classifier_head.train(mode)
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert isinstance(self.encoder.avgpool, torch.nn.Module)
        with torch.no_grad():
            # Generate representations
            repr = self.encoder(x)

            # Generate image embeddings
            self.projection(repr)
            # Generate class logits
            agg_repr = self.encoder.avgpool(repr) if repr.ndim > 2 else repr
            agg_repr = agg_repr.reshape(agg_repr.size(0), -1).detach()

        return self.classifier_head(agg_repr)


class PretrainedClassifier(torch.nn.Module):
    def __init__(self, num_classes: int, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder
        self.classifier_head = torch.nn.Linear(get_encoder_output_dim(self.encoder), num_classes)

    def train(self, mode: bool = True) -> Any:
        self.classifier_head.train(mode)
        self.encoder.train(mode)
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Generate representations
        repr = self.encoder(x)
        # Generate class logits
        agg_repr = self.encoder.avgpool(repr) if repr.ndim > 2 else repr  # type: ignore
        agg_repr = agg_repr.reshape(agg_repr.size(0), -1)
        return self.classifier_head(agg_repr)


def get_encoder_output_dim(pl_module: Union[pl.LightningModule, torch.nn.Module],
                           dm: Optional[pl.LightningDataModule] = None) -> int:
    """
    Calculates the output dimension of ssl encoder by making a single forward pass.
    :param pl_module: pl encoder module
    :param dm: pl datamodule
    """
    # Target device
    device = pl_module.device if isinstance(pl_module, pl.LightningDataModule) else \
        next(pl_module.parameters()).device  # type: ignore
    assert (isinstance(device, torch.device))

    # Create a dummy input image
    if dm is not None:
        batch = iter(dm.train_dataloader()).next()  # type: ignore
        x, _ = SSLOnlineEvaluatorInnerEye.to_device(batch, device)
    else:
        x = torch.rand((1, 3, 256, 256)).to(device)

    # Extract the number of output feature dimensions
    with torch.no_grad():
        representations = pl_module(x)

    return representations.shape[1]


def WrapSSL(ssl_class: Any, num_classes: int) -> Any:
    """
    Wraps a given SSL encoder and adds a non-linear evaluator to it. This is done to load pre-trained SSL checkpoints.
    PL requires non_linear_evaluator to be included in pl_module at SSL training time.
    :param num_classes: Number of target classes for the linear head.
    :param ssl_class: SSL object either BYOL or SimCLR.
    """
    class _wrap(ssl_class):  # type: ignore
        def __init__(self, **kwargs: Any) -> None:
            super().__init__(**kwargs)
            self.non_linear_evaluator = SSLEvaluator(n_input=get_encoder_output_dim(self),
                                                     n_classes=num_classes,
                                                     n_hidden=None)

    return _wrap
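A minimal linear-probe sketch of how `PretrainedClassifier` is intended to be used: only the linear head is optimised on top of a (here hypothetical) pre-trained backbone. The toy encoder, batch size and learning rate below are illustrative assumptions, not values from this commit.

import torch

# Hypothetical toy encoder standing in for a pre-trained backbone; any module whose
# forward pass returns a (batch, features) tensor can be probed this way.
toy_encoder = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten())

probe = PretrainedClassifier(num_classes=10, encoder=toy_encoder)
# Optimise only the linear head to evaluate the quality of the frozen representation.
optimizer = torch.optim.Adam(probe.classifier_head.parameters(), lr=1e-3)

images = torch.rand(4, 3, 256, 256)
labels = torch.randint(0, 10, (4,))
optimizer.zero_grad()
logits = probe(images)                                   # shape: (4, 10)
loss = torch.nn.functional.cross_entropy(logits, labels)
loss.backward()
optimizer.step()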
@ -0,0 +1,63 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

from typing import Optional
import logging

import torch
from InnerEyeDataQuality.deep_learning.self_supervised.ssl_classifier_module import PretrainedClassifier, SSLClassifier, \
    WrapSSL
from pl_bolts.models.self_supervised.resnets import resnet18, resnet50_bn, resnet101


def create_ssl_encoder(encoder_name: str, dataset_name: Optional[str] = None) -> torch.nn.Module:
    """
    Creates a ResNet encoder for SSL training. For CIFAR10-sized inputs, the initial convolution
    is replaced so that the small images are not shrunk.
    """
    if encoder_name == 'resnet18':
        encoder = resnet18(return_all_feature_maps=False)
    elif encoder_name == 'resnet50':
        encoder = resnet50_bn(return_all_feature_maps=False)
    elif encoder_name == 'resnet101':
        encoder = resnet101(return_all_feature_maps=False)
    else:
        raise ValueError("Unknown model type")

    if dataset_name is not None:
        if dataset_name in ["CIFAR10", "CIFAR10H"]:
            logging.info("Updating the initial convolution in order not to shrink CIFAR10 images")
            encoder.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

    return encoder


def create_ssl_image_classifier(num_classes: int, pl_checkpoint_path: str, freeze_encoder: bool = True) -> torch.nn.Module:
    """
    Creates an image classifier from a pre-trained SSL checkpoint (BYOL or SimCLR), either with a
    frozen encoder and linear head or as a fully trainable classifier.
    """
    from InnerEyeDataQuality.deep_learning.self_supervised.byol.byol_module import BYOLInnerEye
    from InnerEyeDataQuality.deep_learning.self_supervised.simclr_module import SimCLRInnerEye

    ssl_type = torch.load(pl_checkpoint_path, map_location=lambda storage, loc: storage)["hyper_parameters"]["ssl_type"]
    logging.info(f"Creating a {ssl_type} based image classifier")
    logging.info(f"Loading pretrained {ssl_type} weights from:\n {pl_checkpoint_path}")

    if ssl_type == "byol":
        byol_module = WrapSSL(BYOLInnerEye, num_classes).load_from_checkpoint(pl_checkpoint_path, strict=False)
        if freeze_encoder:
            model = SSLClassifier(num_classes=num_classes, encoder=byol_module.target_network.encoder,
                                  projection=byol_module.target_network.projector_normalised)
        else:
            model = PretrainedClassifier(num_classes=num_classes,  # type: ignore
                                         encoder=byol_module.target_network.encoder)
    elif ssl_type == "simclr":
        simclr_module = WrapSSL(SimCLRInnerEye, num_classes).load_from_checkpoint(pl_checkpoint_path, strict=False)
        if freeze_encoder:
            model = SSLClassifier(num_classes=num_classes, encoder=simclr_module.encoder,
                                  projection=simclr_module.projection)
        else:
            model = PretrainedClassifier(num_classes=num_classes,  # type: ignore
                                         encoder=simclr_module.encoder)
    else:
        raise NotImplementedError(f"Unknown unsupervised model: {ssl_type}")

    return model
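A short usage sketch for building a downstream classifier from an SSL checkpoint. The checkpoint path and number of classes below are hypothetical; any PyTorch Lightning checkpoint produced by the BYOL/SimCLR training entry point should work.

import torch

checkpoint = "logs/ssl/byol_pretraining.ckpt"   # hypothetical checkpoint path
classifier = create_ssl_image_classifier(num_classes=2,
                                         pl_checkpoint_path=checkpoint,
                                         freeze_encoder=True)
classifier.eval()
with torch.no_grad():
    posteriors = torch.softmax(classifier(torch.rand(1, 3, 256, 256)), dim=-1)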
@ -0,0 +1,51 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import argparse
import logging
import os

from InnerEyeDataQuality.configs.config_node import ConfigNode
from InnerEyeDataQuality.deep_learning.utils import create_logger, get_run_config, load_model_config
from InnerEyeDataQuality.deep_learning.trainers.co_teaching_trainer import CoTeachingTrainer
from InnerEyeDataQuality.deep_learning.trainers.elr_trainer import ELRTrainer
from InnerEyeDataQuality.deep_learning.trainers.vanilla_trainer import VanillaTrainer
from InnerEyeDataQuality.utils.generic import set_seed


def train(config: ConfigNode) -> None:
    create_logger(config.train.output_dir)
    logging.info('Starting training...')
    if config.train.use_co_teaching and config.train.use_elr:
        raise ValueError("You asked for co-teaching and ELR at the same time. Please double check your configuration.")
    if config.train.use_co_teaching:
        model_trainer_class = CoTeachingTrainer
    elif config.train.use_elr:
        model_trainer_class = ELRTrainer  # type: ignore
    else:
        model_trainer_class = VanillaTrainer  # type: ignore
    model_trainer_class(config).run_training()


def train_ensemble(config: ConfigNode, num_runs: int) -> None:
    for i, _ in enumerate(range(num_runs)):
        config_run = get_run_config(config, config.train.seed + i)
        set_seed(config_run.train.seed)
        os.makedirs(config_run.train.output_dir, exist_ok=True)
        train(config_run)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Parser for model training.')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to config file characterising trained CNN model/s')
    parser.add_argument('--num_runs', type=int, default=1, help='Number of runs (ensemble)')
    args, unknown_args = parser.parse_known_args()

    # Load config
    config = load_model_config(args.config)

    # Launch training
    train_ensemble(config, args.num_runs)
@ -0,0 +1,5 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

@ -0,0 +1,203 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

from typing import Any, List, Optional, Tuple

import torch
from torch.utils.data.dataloader import DataLoader

from InnerEyeDataQuality.configs.config_node import ConfigNode
from InnerEyeDataQuality.deep_learning.architectures.ema import EMA
from InnerEyeDataQuality.deep_learning.collect_embeddings import get_all_embeddings
from InnerEyeDataQuality.deep_learning.graph.classifier import GraphClassifier
from InnerEyeDataQuality.deep_learning.loss import CrossEntropyLoss, consistency_loss
from InnerEyeDataQuality.deep_learning.trainers.model_trainer_base import IndexContainer, Loss, ModelTrainer
from InnerEyeDataQuality.deep_learning.scheduler import ForgetRateScheduler
from InnerEyeDataQuality.deep_learning.utils import create_model
from InnerEyeDataQuality.utils.generic import find_union_set_torch, map_to_device


class CoTeachingTrainer(ModelTrainer):
    """
    Implements co-teaching training using two models.
    """

    def __init__(self, config: ConfigNode):
        self.use_teacher_model = config.train.use_teacher_model
        super().__init__(config)

        self.forget_rate_scheduler = ForgetRateScheduler(
            config.scheduler.epochs,
            forget_rate=config.train.co_teaching_forget_rate,
            num_gradual=config.train.co_teaching_num_gradual,
            start_epoch=config.train.resume_epoch if config.train.resume_epoch > 0 else 0,
            num_warmup_epochs=config.train.co_teaching_num_warmup)

        self.joint_metric_tracker = self.train_trackers[0]
        self.use_consistency_loss = config.train.co_teaching_consistency_loss
        self.use_graph = config.train.co_teaching_use_graph
        self.consistency_loss_weight = 0.10
        self.num_models = len(self.models)
        self.loss_fn = CrossEntropyLoss(config)
        self.ema_models = [EMA(self.models[0]), EMA(self.models[1])] if config.train.use_teacher_model else None
        self.graph_classifiers = [GraphClassifier(num_samples=len(self.train_loader.dataset),  # type: ignore
                                                  num_classes=config.dataset.n_classes,
                                                  labels=self.train_loader.dataset.targets,  # type: ignore
                                                  device=config.device)
                                  for _ in range(2)]

    # Create two models for co-teaching
    def get_models(self, config: ConfigNode) -> List[torch.nn.Module]:
        """
        :param config: The job config
        :return: A list of two models to be trained
        """
        return [create_model(config, model_id=0), create_model(config, model_id=1)]

    def update_teacher_models(self) -> None:
        if self.ema_models:
            for i in range(len(self.models)):
                self.ema_models[i].update()

    def deploy_teacher_models(self, inputs: torch.Tensor) -> Optional[List[torch.Tensor]]:
        if not self.ema_models:
            return None
        elif isinstance(inputs, list):
            return [_m.inference(_i) for _m, _i in zip(self.ema_models, inputs)]
        else:
            return [_m.inference(inputs) for _m in self.ema_models]

    @torch.no_grad()
    def _get_samples_for_update(self,
                                outputs: List[torch.Tensor],
                                labels: torch.Tensor,
                                global_indices: torch.Tensor,
                                teacher_logits: Optional[List[torch.Tensor]]) -> Tuple[IndexContainer, IndexContainer]:
        """
        Return a list of indices that should be kept for gradient updates.
        """

        def _get_small_loss_sample_ids(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            num_samples = labels.shape[0]
            num_remember = int((1. - self.forget_rate_scheduler.get_forget_rate) * num_samples)
            per_sample_loss = self.loss_fn(logits, labels, reduction='none')
            ind_sorted = torch.argsort(per_sample_loss)
            return ind_sorted[:num_remember], ind_sorted[num_remember:]

        judge = lambda i: teacher_logits[i] if self.use_teacher_model and teacher_logits else outputs[i]
        ind_1_keep, ind_1_exc = _get_small_loss_sample_ids(judge(0))
        ind_0_keep, ind_0_exc = _get_small_loss_sample_ids(judge(1))

        # Use graph based classifier
        if self.graph_classifiers[0].graph is not None:
            ind_0_keep, ind_0_exc = self.graph_classifiers[0].filter_cases(ind_0_keep, ind_0_exc, global_indices)
            ind_1_keep, ind_1_exc = self.graph_classifiers[1].filter_cases(ind_1_keep, ind_1_exc, global_indices)

        return IndexContainer(ind_0_keep, ind_0_exc), IndexContainer(ind_1_keep, ind_1_exc)

    # Co-teaching loss
    def compute_loss(self,
                     outputs: List[torch.Tensor],
                     labels: torch.Tensor,
                     indices: Optional[Tuple[IndexContainer, IndexContainer]] = None,
                     **kwargs: Any) -> List[Loss]:
        """
        Implements the co-teaching loss using the outputs of two different models
        :param outputs: A list of logits output by each model
        :param labels: The target labels
        :param indices: Sample indices that should be kept and excluded in loss computation (for both models).
        :return: A list of Loss objects; each element contains the loss that is fed to the optimizer and a
                 tensor of per-sample losses
        """
        options = {'ema_logits': None}
        options.update(kwargs)
        ema_logits: Optional[List[torch.Tensor]] = options['ema_logits']

        loss_obj = list()
        indices: Tuple[IndexContainer, IndexContainer] = [  # type: ignore
            IndexContainer(keep=torch.arange(labels.shape[0]), exclude=torch.tensor([], dtype=torch.long)) for _ in
            range(len(outputs))] if indices is None else indices
        assert indices is not None  # for mypy
        assert len(outputs) == len(indices) == 2
        for _output, _index in zip(outputs, indices):
            # The indices to keep for each model are determined by the loss of the other.
            per_sample_loss = self.loss_fn(_output, labels, reduction='none')
            if self.weight is not None:
                per_sample_loss *= self.weight[labels]
            loss_update = torch.mean(per_sample_loss[_index.keep])
            loss_obj.append(Loss(per_sample_loss, loss_update))

        # Consistency loss between predictions on noisy samples
        if self.use_consistency_loss and ema_logits:
            joint_excluded = find_union_set_torch(indices[0].exclude, indices[1].exclude)
            c_loss0 = consistency_loss(outputs[0][joint_excluded], ema_logits[0][joint_excluded])
            c_loss1 = consistency_loss(outputs[1][joint_excluded], ema_logits[1][joint_excluded])
            loss_obj[0].loss += self.consistency_loss_weight * c_loss0
            loss_obj[1].loss += self.consistency_loss_weight * c_loss1

        return loss_obj

    def run_epoch(self, dataloader: DataLoader, epoch: int, is_train: bool = False) -> None:
        """
        Run a training or validation epoch of the base model trainer but also step the forget rate scheduler
        :param dataloader: A dataloader object.
        :param epoch: Current epoch id.
        :param is_train: Whether this is a training epoch or not.
        :param run_inference_on_training_set: If True, record all metrics using the train_trackers
                                              (even if is_train = False)
        :return:
        """
        for model in self.models:
            model.train() if is_train else model.eval()
        trackers = self.train_trackers if is_train else self.val_trackers
        # Consume input dataloader and update model
        for indices, images, labels in dataloader:
            images, labels = map_to_device(images, self.device), labels.to(self.device)
            outputs = self.forward(images, requires_grad=is_train)
            ema_logits = self.deploy_teacher_models(images)
            selected_ind = self._get_samples_for_update(outputs, labels, indices, ema_logits) if is_train else None
            losses = self.compute_loss(outputs, labels, selected_ind, ema_logits=ema_logits)
            assert (len(outputs) == len(losses)) & (len(outputs) == 2)

            if is_train:
                assert selected_ind is not None
                self.step_optimizers(losses)
                self.update_teacher_models()
                self.joint_metric_tracker.append_batch_aggregate(epoch=epoch,
                                                                 logits_x=outputs[0].detach(),
                                                                 logits_y=outputs[1].detach(),
                                                                 dropped_cases=indices[selected_ind[0].exclude],
                                                                 indices=indices)

            # Collect model embeddings
            embeddings = get_all_embeddings(self.all_model_cnn_embeddings)
            # Log training and validation stats in metric tracker
            for i, (logits, loss) in enumerate(zip(outputs, losses)):
                teacher_logits = ema_logits[i].detach() if self.ema_models else None  # type: ignore
                trackers[i].sample_metrics.append_batch(epoch=epoch,
                                                        logits=logits.detach(),
                                                        labels=labels.detach(),
                                                        loss=loss.loss.item(),
                                                        indices=indices.tolist(),
                                                        per_sample_loss=loss.per_sample_loss.detach(),
                                                        embeddings=embeddings[i],
                                                        teacher_logits=teacher_logits)

        # Adjust forget rate for co-teaching
        if is_train:
            self.forget_rate_scheduler.step()
            if self.use_graph:
                self.graph_classifiers[0].build_graph(embeddings=trackers[1].sample_metrics.embeddings_per_sample)
                self.graph_classifiers[1].build_graph(embeddings=trackers[0].sample_metrics.embeddings_per_sample)

    def save_checkpoint(self, epoch: int) -> bool:
        is_save = super().save_checkpoint(epoch=epoch)
        if is_save and self.ema_models:
            is_last_epoch = epoch == self.config.scheduler.epochs - 1
            suffix = '_last_epoch.pt' if is_last_epoch else f'_epoch_{epoch:d}.pt'
            for i, ema_model in enumerate(self.ema_models):
                save_path = str(self.checkpoint_dir / f'checkpoint_ema_model_{i:d}') + suffix
                ema_model.save_model(save_path)
        return is_save
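A standalone sketch of the small-loss selection rule used by _get_samples_for_update and compute_loss above: each network keeps only the (1 - forget_rate) fraction of samples that its peer considers "clean" (smallest loss) and takes gradients on those. Here forget_rate is a fixed illustrative value; in the trainer it is ramped by ForgetRateScheduler, and the teacher (EMA) logits and graph filtering are omitted.

import torch
import torch.nn.functional as F

def small_loss_indices(logits: torch.Tensor, labels: torch.Tensor, forget_rate: float) -> torch.Tensor:
    """Indices of the (1 - forget_rate) fraction of samples with the smallest per-sample loss."""
    per_sample_loss = F.cross_entropy(logits, labels, reduction='none')
    num_remember = int((1.0 - forget_rate) * labels.shape[0])
    return torch.argsort(per_sample_loss)[:num_remember]

logits_0 = torch.randn(16, 10)   # model 0 outputs for one mini-batch
logits_1 = torch.randn(16, 10)   # model 1 outputs for the same mini-batch
labels = torch.randint(0, 10, (16,))

# Cross-update: each model is trained only on the samples its peer ranks as small-loss.
keep_for_model_0 = small_loss_indices(logits_1, labels, forget_rate=0.2)
keep_for_model_1 = small_loss_indices(logits_0, labels, forget_rate=0.2)
loss_model_0 = F.cross_entropy(logits_0[keep_for_model_0], labels[keep_for_model_0])
loss_model_1 = F.cross_entropy(logits_1[keep_for_model_1], labels[keep_for_model_1])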
@ -0,0 +1,96 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

from typing import List

import torch
from torch import nn
import torch.nn.functional as F
from torch.types import Device
from torch.utils.data import DataLoader

from InnerEyeDataQuality.configs.config_node import ConfigNode
from InnerEyeDataQuality.deep_learning.collect_embeddings import get_all_embeddings
from InnerEyeDataQuality.deep_learning.loss import CrossEntropyLoss
from InnerEyeDataQuality.deep_learning.trainers.model_trainer_base import Loss
from InnerEyeDataQuality.deep_learning.trainers.vanilla_trainer import VanillaTrainer


class ELRTrainer(VanillaTrainer):
    def __init__(self, config: ConfigNode):
        super().__init__(config)
        self.num_classes = config.dataset.n_classes
        self.loss_fn = {"TRAIN": ELRLoss(num_examples=self.train_trackers[0].num_samples_total,
                                         num_classes=config.dataset.n_classes,
                                         device=self.device),
                        "VAL": CrossEntropyLoss(config)}

    def compute_loss(self, is_train: bool, outputs: List[torch.Tensor], labels: torch.Tensor,
                     indices: torch.Tensor = None) -> Loss:
        """
        Implements the ELR loss at training time and standard cross-entropy at validation time, using one model
        :param outputs: A list of logits output by each model
        :param labels: The target labels
        :return: A Loss object containing the loss that is fed to the optimizer and a
                 tensor of per-sample losses
        """
        logits = outputs[0]
        if is_train:
            per_sample_loss = self.loss_fn["TRAIN"](predictions=logits, targets=labels, indices=indices)  # type: ignore
        else:
            per_sample_loss = self.loss_fn["VAL"](predictions=logits, targets=labels, reduction='none')  # type: ignore
        loss = torch.mean(per_sample_loss)
        return Loss(per_sample_loss, loss)

    def run_epoch(self, dataloader: DataLoader, epoch: int, is_train: bool = False) -> None:
        """
        Run a training or validation epoch of the base model trainer
        :param dataloader: A dataloader object
        :param epoch: Current epoch id.
        :param is_train: Whether this is a training epoch or not
        :param run_inference_on_training_set: If True, record all metrics using the train_trackers
                                              (even if is_train = False)
        :return:
        """

        for model in self.models:
            model.train() if is_train else model.eval()

        for indices, images, labels in dataloader:
            images, labels = images.to(self.device), labels.to(self.device)
            outputs = self.forward(images, requires_grad=is_train)
            embeddings = get_all_embeddings(self.all_model_cnn_embeddings)[0]
            losses = self.compute_loss(is_train, outputs, labels, indices)
            if is_train:
                self.step_optimizers([losses])

            # Log training and validation stats in metric tracker
            tracker = self.train_trackers[0] if is_train else self.val_trackers[0]
            tracker.sample_metrics.append_batch(epoch, outputs[0].detach(), labels.detach(), losses.loss.item(),
                                                indices.cpu().tolist(), losses.per_sample_loss.detach(), embeddings)


class ELRLoss(nn.Module):
    """
    Adapted from https://github.com/shengliu66/ELR.
    """

    def __init__(self, num_examples: int, num_classes: int, device: Device, beta: float = 0.9, _lambda: float = 3):
        super().__init__()
        self.num_classes = num_classes
        self.targets = torch.zeros(num_examples, self.num_classes, device=device)
        self.beta = beta
        self._lambda = _lambda

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        y_pred = F.softmax(predictions, dim=1)
        y_pred = torch.clamp(y_pred, 1e-4, 1.0 - 1e-4)
        y_pred_ = y_pred.data.detach()
        self.targets[indices] = self.beta * self.targets[indices] + (1 - self.beta) * (
            (y_pred_) / (y_pred_).sum(dim=1, keepdim=True))
        ce_loss = F.cross_entropy(predictions, targets, reduction="none")
        elr_reg = (1 - (self.targets[indices] * y_pred).sum(dim=1)).log()
        per_sample_loss = ce_loss + self._lambda * elr_reg
        return per_sample_loss
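A minimal usage sketch of ELRLoss: the loss keeps a running average of the model's own predictions per training sample, t_i <- beta * t_i + (1 - beta) * p_i, and adds the regulariser lambda * log(1 - <t_i, p_i>) to the cross-entropy, which discourages the model from drifting away from its early (typically clean) predictions on noisy labels. The dataset size, batch and device below are hypothetical.

import torch

elr_loss = ELRLoss(num_examples=100, num_classes=10, device=torch.device('cpu'))

logits = torch.randn(8, 10, requires_grad=True)
labels = torch.randint(0, 10, (8,))
sample_indices = torch.arange(8)   # positions of this batch within the full training set

per_sample_loss = elr_loss(predictions=logits, targets=labels, indices=sample_indices)
per_sample_loss.mean().backward()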
@ -0,0 +1,284 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import logging
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union

import torch
from torch.utils.data.dataloader import DataLoader

from InnerEyeDataQuality.configs.config_node import ConfigNode
from InnerEyeDataQuality.deep_learning.architectures.ema import EMA
from InnerEyeDataQuality.deep_learning.collect_embeddings import register_embeddings_collector
from InnerEyeDataQuality.deep_learning.dataloader import (get_number_of_samples_per_epoch, get_train_dataloader,
                                                          get_val_dataloader)
from InnerEyeDataQuality.deep_learning.metrics.tracker import MetricTracker
from InnerEyeDataQuality.utils.dataset_utils import get_datasets
from PyTorchImageClassification.optim import create_optimizer
from PyTorchImageClassification.scheduler import create_scheduler


@dataclass(frozen=True)
class IndexContainer:
    keep: torch.Tensor
    exclude: torch.Tensor


@dataclass
class Loss:
    per_sample_loss: torch.Tensor  # the loss value computed for each sample without any modifications
    loss: torch.Tensor  # the loss that is fed to the optimizer; can be derived from per_sample loss given some rule


class ModelTrainer(object):
    """
    ModelTrainer class handles the training of models, logging, saving checkpoints and validation. This class trains
    one or more similar models at a time, each having its own optimizer, but the models can interact in the loss
    function. This is an abstract class for which some methods are left not implemented;
    child classes must implement these methods.
    """

    def __init__(self, config: ConfigNode) -> None:
        self.config = config
        self.checkpoint_dir = Path(self.config.train.output_dir) / 'checkpoints'
        self.log_dir = Path(self.config.train.output_dir) / 'logs'
        self.seed = config.train.seed
        self.device = torch.device(config.device)
        train_dataset, val_dataset = get_datasets(config)
        self.weight = torch.tensor([train_dataset.weight, (1-train_dataset.weight)],  # type: ignore
                                   device=self.device, dtype=torch.float) if hasattr(train_dataset, "weight") else None
        self.train_loader = get_train_dataloader(train_dataset, config, seed=self.seed,
                                                 drop_last=config.train.dataloader.drop_last, shuffle=True)
        self.val_loader = get_val_dataloader(val_dataset, config, seed=self.seed)
        self.models = self.get_models(config)
        self.ema_models: Optional[List[EMA]] = None
        self.schedulers = [create_scheduler(config, create_optimizer(config, model), len(self.train_loader))
                           for model in self.models]
        self.train_trackers, self.val_trackers = self._create_metric_trackers(config)
        self.all_trackers = self.train_trackers + self.val_trackers
        self.all_model_cnn_embeddings = register_embeddings_collector(self.models, use_only_in_train=True)

    def _create_metric_trackers(self, config: ConfigNode) -> Tuple[List[MetricTracker], List[MetricTracker]]:
        """
        Creates metric trackers used at model training and validation.
        """
        train_loader = self.train_loader
        val_loader = self.val_loader
        num_models = len(self.models)

        if hasattr(train_loader.dataset, "ambiguity_metric_args"):
            ambiguity_metric_args = train_loader.dataset.ambiguity_metric_args  # type: ignore
        else:
            ambiguity_metric_args = dict()

        save_tf_events = config.tensorboard.save_events if hasattr(config.tensorboard, 'save_events') else True
        metric_kwargs = {"num_epochs": config.scheduler.epochs,
                         "num_classes": config.dataset.n_classes,
                         "save_tf_events": save_tf_events}
        # A dataset without augmentations and not normalized
        dataset_train, dataset_val = get_datasets(config)
        train_trackers = [MetricTracker(dataset=dataset_train,
                                        output_dir=str(self.log_dir / f'model_{i:d}_train'),
                                        num_samples_total=len(train_loader.dataset),  # type: ignore
                                        num_samples_per_epoch=get_number_of_samples_per_epoch(train_loader),
                                        name=f"model_{i}_train",
                                        **{**metric_kwargs, **ambiguity_metric_args})
                          for i in range(num_models)]
        val_trackers = [MetricTracker(dataset=dataset_val,
                                      output_dir=str(self.log_dir / f'model_{i:d}_val'),
                                      num_samples_total=len(val_loader.dataset),  # type: ignore
                                      num_samples_per_epoch=get_number_of_samples_per_epoch(val_loader),
                                      name=f"model_{i}_valid",
                                      **metric_kwargs) for i in range(num_models)]

        return train_trackers, val_trackers

    def get_models(self, config: ConfigNode) -> List[torch.nn.Module]:
        """
        :param config: The job config
        :return: A list of models to be trained
        """
        raise NotImplementedError

    def forward(self, images: torch.Tensor, requires_grad: bool = True) -> List[torch.Tensor]:
        """
        Performs the forward pass for all models in the ModelTrainer class
        :param images: The input images
        :param requires_grad: Flag to indicate if the forward pass requires .grad attributes
        :return: A list of the logits for each model
        """
        def _forward(inputs: Union[List[torch.Tensor], torch.Tensor]) -> List[torch.Tensor]:
            if isinstance(inputs, list):
                return [_model(_input) for _model, _input in zip(self.models, inputs)]
            else:
                return [_model(inputs) for _model in self.models]

        @torch.no_grad()
        def _forward_inference_only(inputs: Union[List[torch.Tensor], torch.Tensor]) -> List[torch.Tensor]:
            return _forward(inputs)

        if requires_grad:
            return _forward(images)
        else:
            return _forward_inference_only(images)

    def compute_loss(self, outputs: List[torch.Tensor], labels: torch.Tensor,
                     indices: Optional[Tuple[IndexContainer, IndexContainer]] = None) -> Union[List[Loss], Loss]:
        """
        Compute the losses that will be optimized
        :param outputs: A list of logits output by each model
        :param labels: The target labels
        :return: A list of Loss objects; each element contains the loss that is fed to the optimizer and a
                 tensor of per-sample losses
        """
        raise NotImplementedError

    def step_optimizers(self, losses: List[Loss]) -> None:
        """
        Take an optimizer step for every model's optimizer
        :param losses: A list of Loss objects
        :return:
        """
        for loss, scheduler in zip(losses, self.schedulers):
            scheduler.optimizer.zero_grad()
            loss.loss.backward()
            scheduler.optimizer.step()
            scheduler.step()

    def run_epoch(self, dataloader: DataLoader, epoch: int, is_train: bool = False) -> None:
        """
        Run a training or validation epoch
        :param dataloader: A dataloader object
        :param epoch: Current training epoch id
        :param is_train: Whether this is a training epoch or not
        :return:
        """
        raise NotImplementedError

    def save_checkpoint(self, epoch: int) -> bool:
        """
        Save checkpoints for the models for the current epoch
        :param epoch: The current epoch
        :return: Save success
        """
        is_last_epoch = epoch == self.config.scheduler.epochs - 1
        is_save = is_last_epoch or (epoch > 0 and epoch % self.config.train.checkpoint_period == 0)

        if is_save:
            logging.info(f"Saving model checkpoints, epoch {epoch}")
            self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
            for ii in range(len(self.models)):
                path = str(self.checkpoint_dir / f'checkpoint_model_{ii:d}')
                full_save_name = path + '_last_epoch.pt' if is_last_epoch else path + f'_epoch_{epoch:d}.pt'
                state = {'epoch': epoch,
                         'model': self.models[ii].state_dict(),
                         'scheduler': self.schedulers[ii].state_dict(),
                         'optimizer': self.schedulers[ii].optimizer.state_dict()}
                torch.save(state, full_save_name)

        return is_save

    def load_checkpoints(self, restore_scheduler: bool, epoch: Optional[int] = None) -> None:
        """
        If epoch is not specified, latest checkpoint files are loaded to restore the state
        of model weights and optimisers
        :param restore_scheduler (bool): Restores the state of optimiser and scheduler from checkpoint
        :param epoch (int): Training epoch id.
        """
        suffix = f'_epoch_{epoch:d}.pt' if epoch else '_last_epoch.pt'

        for ii in range(len(self.models)):
            path = str(self.checkpoint_dir / f'checkpoint_model_{ii:d}') + suffix
            logging.info(f"Loading model-{ii} from checkpoint:\n {path}")
            state = torch.load(str(path))

            self.models[ii].load_state_dict(state['model'])
            if restore_scheduler:
                self.schedulers[ii].load_state_dict(state['scheduler'])
                self.schedulers[ii].optimizer.load_state_dict(state['optimizer'])
            if epoch is not None:
                logging.info(f"Model is loaded from epoch: {epoch}")
                assert state['epoch'] == epoch

        if self.ema_models:
            for ii in range(len(self.ema_models)):
                path = str(self.checkpoint_dir / f'checkpoint_ema_model_{ii:d}') + suffix
                logging.info(f"Loading ema teacher model-{ii} from checkpoint:\n {path}")
                self.ema_models[ii].restore_from_checkpoint(path)

    def run_training(self) -> None:
        """
        Perform model training.
        Model/s specified in config are trained for `num_epoch` epochs and results are stored in tf events.
        """
        num_epochs = self.config.scheduler.epochs
        epoch_range = range(num_epochs)
        if self.config.train.resume_epoch > 0:
            resume_epoch = self.config.train.resume_epoch
            epoch_range = range(resume_epoch + 1, num_epochs)
            self.load_checkpoints(restore_scheduler=self.config.train.restore_scheduler, epoch=resume_epoch)

        # Model evaluation - startup
        logging.info("Running evaluation on the validation set before training ...")
        self.run_epoch(self.val_loader, is_train=False, epoch=0)
        self.val_trackers[0].log_epoch_and_reset(epoch=0)

        # Model training loop
        for epoch in epoch_range:
            epoch_start = time.time()
            logging.info('\n' + f'Epoch {epoch:d}')
            self.run_epoch(self.train_loader, is_train=True, epoch=epoch)
            self.run_epoch(self.val_loader, is_train=False, epoch=epoch)
            self.save_checkpoint(epoch)

            for _t in self.all_trackers:
                _t.log_epoch_and_reset(epoch)
            logging.info('Epoch time: {0:.2f} secs'.format(time.time() - epoch_start))

        # Store loss values for post-training analysis
        for _t in self.all_trackers:
            _t.save_loss()

    def run_inference(self, dataloader: Any, use_mc_sampling: bool = False) -> List[MetricTracker]:
        """
        Deployment of a pre-trained model.
        """
        dataset = dataloader.dataset
        trackers = [MetricTracker(os.path.join(self.config.train.output_dir, f'model_{i:d}_inference'),
                                  num_samples_total=len(dataset),
                                  num_samples_per_epoch=len(dataset),
                                  name=f"model_{i}_inference",
                                  num_epochs=1,
                                  num_classes=self.config.dataset.n_classes,
                                  save_tf_events=False) for i in range(len(self.models))]

        for model in self.models:
            model.eval()
            if use_mc_sampling:
                logging.info("Applying MC sampling at inference time.")
                dropout_layers = [layer for layer in model.modules() if isinstance(layer, torch.nn.Dropout)]
                for _layer in dropout_layers:
                    _layer.training = True

        with torch.no_grad():
            for indices, images, labels in dataloader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.forward(images, requires_grad=False)
                losses = self.compute_loss(outputs, labels, indices=None)
                if not isinstance(losses, List):
                    losses = [losses]

                # Log training and validation stats in metric tracker
                for i, (logits, loss) in enumerate(zip(outputs, losses)):
                    trackers[i].sample_metrics.append_batch(epoch=0,
                                                            logits=logits.detach(),
                                                            labels=labels.detach(),
                                                            loss=loss.loss.item(),
                                                            indices=indices.cpu().tolist(),
                                                            per_sample_loss=loss.per_sample_loss.detach())
        return trackers
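A standalone sketch of the per-model optimisation pattern that step_optimizers implements: one optimizer and one learning-rate scheduler per model, with the scheduler stepped after every batch. Stock torch optimizers and a cosine scheduler are used here as stand-ins; the assumption that create_scheduler returns an object holding its optimizer is inferred from the usage above, not confirmed elsewhere.

import torch

models = [torch.nn.Linear(4, 2), torch.nn.Linear(4, 2)]
optimizers = [torch.optim.SGD(m.parameters(), lr=0.1) for m in models]
schedulers = [torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100) for opt in optimizers]

batch = torch.rand(8, 4)
labels = torch.randint(0, 2, (8,))
for model, optimizer, scheduler in zip(models, optimizers, schedulers):
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(batch), labels)
    loss.backward()
    optimizer.step()
    scheduler.step()   # per-batch scheduler step, mirroring step_optimizers above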
@ -0,0 +1,80 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

from typing import List, Optional, Tuple

import torch

from InnerEyeDataQuality.configs.config_node import ConfigNode
from InnerEyeDataQuality.deep_learning.collect_embeddings import get_all_embeddings
from InnerEyeDataQuality.deep_learning.trainers.model_trainer_base import IndexContainer, Loss, ModelTrainer
from InnerEyeDataQuality.deep_learning.utils import create_model
from InnerEyeDataQuality.deep_learning.loss import tanh_loss
from torch.utils.data.dataloader import DataLoader
from InnerEyeDataQuality.deep_learning.loss import CrossEntropyLoss


class VanillaTrainer(ModelTrainer):
    """
    Implements vanilla cross-entropy training with one model.
    """

    def __init__(self, config: ConfigNode):
        super().__init__(config)
        self.tanh_regularisation = config.train.tanh_regularisation
        self.loss_fn = CrossEntropyLoss(config)

    def get_models(self, config: ConfigNode) -> List[torch.nn.Module]:
        """
        :param config: The job config
        :return: A list with one model to be trained
        """
        return [create_model(config, model_id=0)]

    def compute_loss(self, outputs: List[torch.Tensor], labels: torch.Tensor,
                     indices: Optional[Tuple[IndexContainer, IndexContainer]] = None) -> Loss:
        """
        Implements the standard cross-entropy loss using one model
        :param outputs: A list of logits output by each model
        :param labels: The target labels
        :return: A Loss object containing the loss that is fed to the optimizer and a
                 tensor of per-sample losses
        """
        logits = outputs[0]
        per_sample_loss = self.loss_fn(predictions=logits, targets=labels, reduction='none')
        if self.weight is not None:
            per_sample_loss *= self.weight[labels]
        loss = torch.mean(per_sample_loss)

        if self.tanh_regularisation != 0.0:
            loss += self.tanh_regularisation * tanh_loss(logits)

        return Loss(per_sample_loss, loss)

    def run_epoch(self, dataloader: DataLoader, epoch: int, is_train: bool = False) -> None:
        """
        Run a training or validation epoch of the base model trainer
        :param dataloader: A dataloader object
        :param epoch: Current epoch id.
        :param is_train: Whether this is a training epoch or not
        :param run_inference_on_training_set: If True, record all metrics using the train_trackers
                                              (even if is_train = False)
        :return:
        """
        for model in self.models:
            model.train() if is_train else model.eval()

        for indices, images, labels in dataloader:
            images, labels = images.to(self.device), labels.to(self.device)
            outputs = self.forward(images, requires_grad=is_train)
            embeddings = get_all_embeddings(self.all_model_cnn_embeddings)[0]
            losses = self.compute_loss(outputs, labels)
            if is_train:
                self.step_optimizers([losses])

            # Log training and validation stats in metric tracker
            tracker = self.train_trackers[0] if is_train else self.val_trackers[0]
            tracker.sample_metrics.append_batch(epoch, outputs[0].detach(), labels.detach(), losses.loss.item(),
                                                indices.cpu().tolist(), losses.per_sample_loss.detach(), embeddings)
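A small sketch of the class re-weighting applied in compute_loss above: the per-sample cross-entropy is scaled by the weight of each sample's class before averaging. The two-class weights used here are hypothetical values, not taken from a dataset in this commit.

import torch
import torch.nn.functional as F

logits = torch.randn(8, 2)
labels = torch.randint(0, 2, (8,))
class_weight = torch.tensor([0.3, 0.7])   # hypothetical per-class weights

per_sample_loss = F.cross_entropy(logits, labels, reduction='none')
per_sample_loss = per_sample_loss * class_weight[labels]   # re-weight each sample by its class
loss = per_sample_loss.mean()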
@ -0,0 +1,185 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import random
from typing import Any, Callable, Tuple

import PIL
import PIL.Image
import numpy as np
import torch
import torchvision
from scipy.ndimage import gaussian_filter, map_coordinates
from skimage.filters import rank
from skimage.morphology import disk

from InnerEyeDataQuality.configs.config_node import ConfigNode


class BaseTransform:
    def __init__(self, config: ConfigNode):
        self.transform = lambda x: x

    def __call__(self, data: PIL.Image.Image) -> PIL.Image.Image:
        return self.transform(data)


class Standardize:
    def __init__(self, mean: np.ndarray, std: np.ndarray):
        self.mean = np.array(mean)
        self.std = np.array(std)

    def __call__(self, image: PIL.Image.Image) -> np.ndarray:
        image = np.asarray(image).astype(np.float32) / 255.
        image = (image - self.mean) / self.std
        return image


class CenterCrop(BaseTransform):
    def __init__(self, config: ConfigNode):
        self.transform = torchvision.transforms.CenterCrop(config.preprocess.center_crop_size)


class RandomCrop(BaseTransform):
    def __init__(self, config: ConfigNode):
        self.transform = torchvision.transforms.RandomCrop(
            config.dataset.image_size,
            padding=config.augmentation.random_crop.padding,
            fill=config.augmentation.random_crop.fill,
            padding_mode=config.augmentation.random_crop.padding_mode)


class RandomResizeCrop(BaseTransform):
    def __init__(self, config: ConfigNode):
        self.transform = torchvision.transforms.RandomResizedCrop(
            size=config.preprocess.resize,
            scale=config.augmentation.random_crop.scale)


class RandomHorizontalFlip(BaseTransform):
    def __init__(self, config: ConfigNode):
        self.transform = torchvision.transforms.RandomHorizontalFlip(
            config.augmentation.random_horizontal_flip.prob)


class RandomAffine(BaseTransform):
    def __init__(self, config: ConfigNode):
        self.transform = torchvision.transforms.RandomAffine(degrees=config.augmentation.random_affine.max_angle,  # 15
                                                             translate=(
                                                                 config.augmentation.random_affine.max_horizontal_shift,
                                                                 config.augmentation.random_affine.max_vertical_shift),
                                                             shear=config.augmentation.random_affine.max_shear)


class Resize(BaseTransform):
    def __init__(self, config: ConfigNode):
        self.transform = torchvision.transforms.Resize(config.preprocess.resize)


class RandomColorJitter(BaseTransform):
    def __init__(self, config: ConfigNode) -> None:
        self.transform = torchvision.transforms.ColorJitter(brightness=config.augmentation.random_color.brightness,
                                                            contrast=config.augmentation.random_color.contrast,
                                                            saturation=config.augmentation.random_color.saturation)


class RandomErasing(BaseTransform):
    def __init__(self, config: ConfigNode) -> None:
        self.transform = torchvision.transforms.RandomErasing(p=0.5,
                                                              scale=config.augmentation.random_erasing.scale,
                                                              ratio=config.augmentation.random_erasing.ratio)


class RandomGamma(BaseTransform):
    def __init__(self, config: ConfigNode) -> None:
        self.min, self.max = config.augmentation.gamma.scale

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        gamma = random.uniform(self.min, self.max)
        return torchvision.transforms.functional.adjust_gamma(image, gamma=gamma)


class HistogramNormalization:
    def __init__(self, config: ConfigNode) -> None:
        self.disk_size = config.preprocess.histogram_normalization.disk_size

    def __call__(self, image: PIL.Image.Image) -> np.ndarray:
        # Apply local histogram equalization
        image = np.array(image)
        return PIL.Image.fromarray(rank.equalize(image, selem=disk(self.disk_size)))


class ExpandChannels:
    def __call__(self, data: torch.Tensor) -> torch.Tensor:
        return torch.repeat_interleave(data, 3, dim=0)


class ToNumpy:
    def __call__(self, image: PIL.Image.Image) -> np.ndarray:
        return np.array(image)


class AddGaussianNoise:
    def __init__(self, config: ConfigNode) -> None:
        """
        Transformation to add Gaussian noise N(0, std) to an image, where std is set with the
        config.augmentation.gaussian_noise.std argument. The transformation is applied with probability
        config.augmentation.gaussian_noise.p_apply
        """
        self.std = config.augmentation.gaussian_noise.std
        self.p_apply = config.augmentation.gaussian_noise.p_apply

    def __call__(self, data: torch.Tensor) -> torch.Tensor:
        if np.random.random(1) > self.p_apply:
            return data
        noise = torch.randn(size=data.shape) * self.std
        data = torch.clamp(data + noise, 0, 1)
        return data


class ElasticTransform:
    """Elastic deformation of images as described in [Simard2003]_.
    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
       Convolutional Neural Networks applied to Visual Document Analysis", in
       Proc. of the International Conference on Document Analysis and
       Recognition, 2003.

    https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.160.8494&rep=rep1&type=pdf

    :param sigma: elasticity coefficient
    :param alpha: intensity of the deformation
    :param p_apply: probability of applying the transformation
    """

    def __init__(self, config: ConfigNode) -> None:
        self.alpha = config.augmentation.elastic_transform.alpha
        self.sigma = config.augmentation.elastic_transform.sigma
        self.p_apply = config.augmentation.elastic_transform.p_apply

    def __call__(self, image: PIL.Image) -> PIL.Image:
        if np.random.random(1) > self.p_apply:
            return image
        image = np.asarray(image).squeeze()
        assert len(image.shape) == 2
        shape = image.shape

        dx = gaussian_filter((np.random.random(shape) * 2 - 1), self.sigma, mode="constant", cval=0) * self.alpha
        dy = gaussian_filter((np.random.random(shape) * 2 - 1), self.sigma, mode="constant", cval=0) * self.alpha

        x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
        indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))
        return PIL.Image.fromarray(map_coordinates(image, indices, order=1).reshape(shape))


class DualViewTransformWrapper:
    def __init__(self, transforms: Callable):
        self.transforms = transforms

    def __call__(self, sample: PIL.Image.Image) -> Tuple[Any, Any]:
        transform = self.transforms
        xi = transform(sample)
        xj = transform(sample)
        return xi, xj
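A sketch of how DualViewTransformWrapper is typically composed with an augmentation pipeline to produce the two independently augmented views consumed by SSL training. The specific transforms and parameters below are hypothetical; in practice they are driven by the ConfigNode-based classes above.

import torchvision
from PIL import Image

base_pipeline = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(size=32, scale=(0.8, 1.0)),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.ToTensor()])

dual_view = DualViewTransformWrapper(base_pipeline)
image = Image.new('RGB', (64, 64))
view_1, view_2 = dual_view(image)   # two independently augmented views of the same image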
@ -0,0 +1,154 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import collections
import logging
from importlib import import_module
from pathlib import Path
from typing import Optional, Union

import torch

from yacs.config import CfgNode

from default_paths import PROJECT_ROOT_DIR
from InnerEyeDataQuality.configs.config_node import ConfigNode
from InnerEyeDataQuality.configs import model_config
from InnerEyeDataQuality.configs.selector_config import get_default_selector_config
from InnerEyeDataQuality.deep_learning.self_supervised.configs import ssl_model_config
from InnerEyeDataQuality.deep_learning.self_supervised.utils import create_ssl_image_classifier
from InnerEyeDataQuality.utils.generic import setup_cudnn, get_train_output_dir


def get_run_config(config: ConfigNode, run_seed: int) -> ConfigNode:
    config_run = config.clone()
    config_run.defrost()
    config_run.train.seed = run_seed
    config_run.train.output_dir = get_train_output_dir(config)
    config_run.freeze()
    return config_run


def load_ssl_model_config(config_path: Path) -> ConfigNode:
    '''
    Loads configs required for self-supervised learning. Does not set up cuDNN, as this is
    taken care of by Lightning.
    '''
    config = ssl_model_config.get_default_model_config()
    config.merge_from_file(config_path)
    update_model_config(config)

    # Freeze config entries
    config.freeze()

    return config


def load_model_config(config_path: Path) -> ConfigNode:
    '''
    Loads configs required for model training and inference.
    '''
    config = model_config.get_default_model_config()
    config.merge_from_file(config_path)
    update_model_config(config)
    setup_cudnn(config)

    # Freeze config entries
    config.freeze()

    return config


def update_model_config(config: ConfigNode) -> ConfigNode:
    '''
    Adds dataset-specific parameters to the model config.
    '''
    if config.dataset.name in ['CIFAR10', 'CIFAR100']:
        dataset_dir = f'~/.torch/datasets/{config.dataset.name}'
        config.dataset.dataset_dir = dataset_dir
        config.dataset.image_size = 32
        config.dataset.n_channels = 3
        config.dataset.n_classes = int(config.dataset.name[5:])
    elif config.dataset.name in ['MNIST']:
        dataset_dir = '~/.torch/datasets'
        config.dataset.dataset_dir = dataset_dir
        config.dataset.image_size = 28
        config.dataset.n_channels = 1
        config.dataset.n_classes = 10

    if not torch.cuda.is_available():
        config.device = 'cpu'

    return config


def override_config(source: ConfigNode, overrides: ConfigNode) -> ConfigNode:
    '''
    Overrides the keys and values present in node `overrides` into the source object recursively.
    '''
    for key, value in overrides.items():
        if isinstance(value, collections.Mapping) and value:
            returned = override_config(source.get(key, {}), value)  # type: ignore
            returned = CfgNode(returned)
            source[key] = returned
        else:
            source[key] = overrides[key]
    return source


def load_selector_config(config_path: str) -> ConfigNode:
    """
    Loads a selector config and merges it with its model config.
    """
    selector_config = get_default_selector_config()
    selector_config.merge_from_file(config_path)
    model_config = load_model_config(PROJECT_ROOT_DIR / selector_config.selector.model_config_path)
    merged_config = override_config(source=model_config, overrides=selector_config)
    merged_config.freeze()

    return merged_config


def create_logger(output_dir: Union[str, Path]) -> None:
    if isinstance(output_dir, str):
        output_dir = Path(output_dir)
    log_path = output_dir.absolute() / 'training.log'
    logging.basicConfig(filename=log_path,
                        filemode='w',
                        format='%(asctime)s %(name)-4s %(levelname)-6s %(message)s',
                        datefmt='%m-%d %H:%M',
                        level=logging.DEBUG)

    # define a Handler which writes INFO messages or higher to the sys.stderr
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)

    # set a format which is simpler for console use
    formatter = logging.Formatter('%(name)-4s: %(levelname)-6s %(message)s')

    # tell the handler to use this format
    console.setFormatter(formatter)

    # add the handler to the root logger
    logging.getLogger().addHandler(console)


def create_model(config: ConfigNode, model_id: Optional[int]) -> torch.nn.Module:
    device = torch.device(config.device)
    if config.train.use_self_supervision:
        assert isinstance(model_id, int) and model_id < 2
        model_checkpoint_name = config.train.self_supervision.checkpoints[model_id]
        model_checkpoint_path = PROJECT_ROOT_DIR / model_checkpoint_name
        model = create_ssl_image_classifier(num_classes=config.dataset.n_classes,
                                            pl_checkpoint_path=str(model_checkpoint_path),
                                            freeze_encoder=config.train.self_supervision.freeze_encoder)
    else:
        try:
            module = import_module('PyTorchImageClassification.models'f'.{config.model.type}.{config.model.name}')
        except ModuleNotFoundError:
            module = import_module(
                f'InnerEyeDataQuality.deep_learning.architectures.{config.model.type}.{config.model.name}')
        model = getattr(module, 'Network')(config)
    return model.to(device)
@ -0,0 +1,77 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import numpy as np


def compute_label_entropy(label_counts: np.ndarray) -> np.ndarray:
    """
    :param label_counts: Input label histogram (n_samples x n_classes)
    """
    label_distribution = label_counts / np.sum(label_counts, axis=-1, keepdims=True)
    return cross_entropy(label_distribution, label_distribution)


def compute_accuracy(source: np.ndarray, target: np.ndarray) -> float:
    """
    Computes the agreement rate between two tensors
    :param source: source array (n_samples, n_classes)
    :param target: target array (n_samples, n_classes)
    """
    # Compare single label case against all available labels
    source = np.argmax(source, axis=1)
    target = np.argmax(target, axis=1)

    # Accuracy
    acc = 100.0 * np.sum(source == target) / source.size

    return acc


def compute_model_disagreement_score(posteriors: np.ndarray) -> np.ndarray:
    """
    Measure model disagreement score (Ref BALD)
    :param posteriors: numpy array (shape: (model_candidates, batch_size, num_classes))
    :return: Disagreement score (BALD) for each sample (shape: (batch))
    """
    def _entropy(x: np.ndarray, log_base: int, epsilon: float = 1e-12) -> np.ndarray:
        return -np.sum(x * (np.log(x + epsilon) / np.log(log_base)), axis=-1)
    num_classes = int(posteriors.shape[-1])
    avg_posteriors = np.mean(posteriors, axis=0)
    avg_entropy = _entropy(avg_posteriors, num_classes)
    exp_conditional_entropy = np.mean(_entropy(posteriors, num_classes), axis=0)
    bald_score = avg_entropy - exp_conditional_entropy
    return bald_score


def cross_entropy(predicted_distribution: np.ndarray, target_distribution: np.ndarray) -> np.ndarray:
    """
    Compute the normalised cross-entropy between the predicted and target distributions
    :param predicted_distribution: Predicted distribution shape = (num_samples, num_classes)
    :param target_distribution: Target distribution shape = (num_samples, num_classes)
    :return: The cross-entropy for each sample
    """
    num_classes = predicted_distribution.shape[1]
    return -np.sum(target_distribution * np.log(predicted_distribution + 1e-12) / np.log(num_classes), axis=-1)


def max_prediction_error(predicted_distribution: np.ndarray, target_distribution: np.ndarray) -> np.ndarray:
    """
    Compute the max (class-wise) prediction error between the predicted and target distributions
    :param predicted_distribution: Predicted distribution shape = (num_samples, num_classes)
    :param target_distribution: Target distribution shape = (num_samples, num_classes)
    :return: The max (class-wise) prediction error for each sample
    """
    current_target_class = np.argmax(target_distribution, axis=1)
    current_target_pred_prob = predicted_distribution[range(len(current_target_class)), current_target_class]
    prediction_errors = 1.0 - current_target_pred_prob
    return prediction_errors


def total_variation(predicted_distribution: np.ndarray, target_distribution: np.ndarray) -> np.ndarray:
    """
    Compute the total variation error between the predicted and target distributions
    :param predicted_distribution: Predicted distribution shape = (num_samples, num_classes)
    :param target_distribution: Target distribution shape = (num_samples, num_classes)
    :return: The total variation for each sample
    """
    return np.sum(np.abs(predicted_distribution - target_distribution), axis=-1) / 2.
@ -0,0 +1,408 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
|
||||
from itertools import cycle
|
||||
from matplotlib import pyplot as plt
|
||||
from pathlib import Path
|
||||
from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_auc_score, roc_curve
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from default_paths import MAIN_SIMULATION_DIR
|
||||
from InnerEyeDataQuality.selection.simulation_statistics import SimulationStatsDistribution
|
||||
from InnerEyeDataQuality.utils.generic import create_folder, save_obj
|
||||
|
||||
|
||||
def _plot_mean_and_confidence_intervals(ax: plt.Axes, y: np.ndarray, color: str, label: str, linestyle: str = '-',
|
||||
clip: Tuple[float, float] = (-np.inf, np.inf),
|
||||
use_derivative: bool = False,
|
||||
include_auc: bool = True,
|
||||
linewidth: float = 3.0) -> None:
|
||||
if use_derivative:
|
||||
y = (y - y[:, 0][:, np.newaxis]) / (np.arange(y.shape[1]) + 1)
|
||||
|
||||
x = np.arange(0, y.shape[1])
|
||||
mean = np.mean(y, axis=0)
|
||||
total_auc = auc([0, np.max(x)], [100, 100])
|
||||
auc_mean = auc(x, mean) / total_auc
|
||||
label += f" - AUC: {auc_mean:.4f}" if include_auc else ""
|
||||
std_error = np.std(y, axis=0) / np.sqrt(y.shape[0]) * 1.96
|
||||
ax.plot(x, mean, color=color, label=label, linestyle=linestyle, linewidth=linewidth)
|
||||
ax.fill_between(x,
|
||||
np.clip((mean - std_error), a_min=clip[0], a_max=clip[1]),
|
||||
np.clip((mean + std_error), a_min=clip[0], a_max=clip[1]),
|
||||
color=color, alpha=.2)
|
||||
|
||||
|
||||
def plot_stats_for_all_selectors(stats: Dict[str, SimulationStatsDistribution],
|
||||
y_attr_names: List[str],
|
||||
y_attr_labels: List[str],
|
||||
title: str = '',
|
||||
x_label: str = '',
|
||||
y_label: str = '',
|
||||
legend_loc: Union[int, str] = 2,
|
||||
fontsize: int = 12,
|
||||
figsize: Tuple[int, int] = (14, 10),
|
||||
plot_n_not_ambiguous_noise_cases: bool = False,
|
||||
**plot_kwargs: Any) -> Tuple[plt.Figure, plt.Axes]:
|
||||
"""
|
||||
Given the dictionary with SimulationStatsDistribution plot a curve for each attribute in the y_attr_names list for
|
||||
each selector. Total number of curves will be num_selectors * num_y_attr_names.
|
||||
:param stats: A dictionary where each entry corresponds to the SimulationStatsDistribution of each selector.
|
||||
:param y_attr_names: The names of the attributes to plot on the y axis.
|
||||
:param y_attr_labels: The labels for the legend of the attributes to plot on the y axis.
|
||||
:param title: The title of the figure.
|
||||
:param x_label: The title of the x-axis.
|
||||
:param y_label: The title of the y-axis.
|
||||
:param legend_loc: The location of the legend.
|
||||
:param plot_n_not_ambiguous_noise_cases: If True, indicate the number of noise easy cases left at the end of the
|
||||
simulation
|
||||
for each selector. If all cases are selected indicate the corresponding iteration.
|
||||
:return: The figure and axis with all the curves plotted.
|
||||
"""
|
||||
colors = ['blue', 'red', 'orange', 'brown', 'black', 'purple', 'green', 'gray', 'olive', 'cyan', 'yellow', 'pink']
|
||||
linestyles = ['-', '--']
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
|
||||
# Sort alphabetically
|
||||
ordered_keys = sorted(stats.keys())
|
||||
# Put always oracle and random first (this crashes if oracle / random not there)
|
||||
ordered_keys.remove("Oracle")
|
||||
ordered_keys.remove("Random")
|
||||
ordered_keys.insert(0, "Oracle")
|
||||
ordered_keys.insert(1, "Random")
|
||||
|
||||
for _stat_key, color in zip(ordered_keys, cycle(colors)):
|
||||
_stat = stats[_stat_key]
|
||||
if '(Posterior)' in _stat.name:
|
||||
_stat.name = _stat.name.split('(Posterior)')[0]
|
||||
for y_attr_name, y_attr_label, linestyle in zip(y_attr_names, y_attr_labels, cycle(linestyles)):
|
||||
color = assign_preset_color(name=_stat.name, color=color)
|
||||
_plot_mean_and_confidence_intervals(ax=ax, y=_stat.__getattribute__(y_attr_name), color=color,
|
||||
label=y_attr_label + _stat.name, linestyle=linestyle, **plot_kwargs)
|
||||
if plot_n_not_ambiguous_noise_cases:
|
||||
average_value_attribute = np.mean(_stat.__getattribute__(y_attr_name), axis=0)
|
||||
std_value_attribute = np.std(_stat.__getattribute__(y_attr_name), axis=0)
|
||||
logging.debug(f"Method {_stat.name} - {y_attr_name} std: {std_value_attribute[-1]}")
|
||||
n_average_mislabelled_not_ambiguous = np.mean(_stat.num_remaining_mislabelled_not_ambiguous, axis=0)
|
||||
ax.text(average_value_attribute.size + 1, average_value_attribute[-1],
|
||||
f"{average_value_attribute[-1]:.2f}", fontsize=fontsize - 4)
|
||||
no_mislabelled_not_ambiguous_remaining = np.where(n_average_mislabelled_not_ambiguous == 0)[0]
|
||||
if no_mislabelled_not_ambiguous_remaining.size > 0:
|
||||
idx = np.min(no_mislabelled_not_ambiguous_remaining)
|
||||
ax.scatter(idx + 1, average_value_attribute[idx], color=color, marker="s")
|
||||
ax.grid()
|
||||
ax.set_title(title, fontsize=fontsize)
|
||||
ax.set_xlabel(x_label, fontsize=fontsize)
|
||||
ax.set_ylabel(y_label, fontsize=fontsize)
|
||||
ax.legend(loc=legend_loc, fontsize=fontsize - 2)
|
||||
ax.xaxis.set_tick_params(labelsize=fontsize - 2)
|
||||
ax.yaxis.set_tick_params(labelsize=fontsize - 2)
|
||||
|
||||
return fig, ax
|
||||
|
||||
|
||||
def assign_preset_color(name: str, color: str) -> str:
|
||||
if re.search('oracle', name, re.IGNORECASE):
|
||||
return 'blue'
|
||||
elif re.search('random', name, re.IGNORECASE):
|
||||
return 'red'
|
||||
elif re.search('bald', name, re.IGNORECASE):
|
||||
if re.search('active', name, re.IGNORECASE):
|
||||
return 'green'
|
||||
return 'purple'
|
||||
elif re.search('clean', name, re.IGNORECASE):
|
||||
return 'pink'
|
||||
elif re.search('self-supervision', name, re.IGNORECASE) \
|
||||
or re.search('SSL', name, re.IGNORECASE) or re.search('pretrained', name, re.IGNORECASE):
|
||||
return 'orange'
|
||||
elif re.search('imagenet', name, re.IGNORECASE):
|
||||
return 'green'
|
||||
elif re.search('self-supervision', name, re.IGNORECASE):
|
||||
if re.search('graph', name, re.IGNORECASE):
|
||||
return 'green'
|
||||
elif re.search('With entropy', name, re.IGNORECASE):
|
||||
return 'olive'
|
||||
elif re.search('active', name, re.IGNORECASE):
|
||||
return 'cyan'
|
||||
else:
|
||||
return 'orange'
|
||||
elif re.search('coteaching', name, re.IGNORECASE):
|
||||
return 'brown'
|
||||
elif re.search('less', name, re.IGNORECASE):
|
||||
return 'grey'
|
||||
elif re.search('vanilla', name, re.IGNORECASE):
|
||||
if re.search('active', name, re.IGNORECASE):
|
||||
return 'grey'
|
||||
return 'black'
|
||||
else:
|
||||
return color
|
||||
|
||||
|
||||
def plot_minimal_sampler(stats: Dict[str, SimulationStatsDistribution],
|
||||
n_samples: int, ax: plt.Axes, fontsize: int = 12, legend_loc: int = 2) -> None:
|
||||
"""
|
||||
Starting with one initial label, a noisy sample needs to be relabeled at least twice (best case scenario
|
||||
the sampled class during relabeling is equal to the correct label
|
||||
"""
|
||||
|
||||
n_mislabelled_ambiguous = list(stats.values())[0].num_initial_mislabelled_ambiguous
|
||||
n_mislabelled_not_ambiguous = list(stats.values())[0].num_initial_mislabelled_not_ambiguous
|
||||
|
||||
accuracy_beginning = 100 * float(n_samples - n_mislabelled_ambiguous - n_mislabelled_not_ambiguous) / n_samples
|
||||
ax.plot([0, 2 * (n_mislabelled_not_ambiguous + n_mislabelled_ambiguous)], [accuracy_beginning, 100], linestyle="--",
|
||||
label="Minimal sampler")
|
||||
|
||||
# accuracy_easy_cases = 100 * float(n_samples - n_mislabelled_ambiguous) / n_samples
|
||||
# max_accuracy = max([np.max(_stat.accuracy) for _stat in stats.values()])
|
||||
# plt.scatter(2 * n_mislabelled_not_ambiguous, accuracy_easy_cases,
|
||||
# marker='s', label="No non-ambiguous noise cases left")
|
||||
# ax.set_ylim(accuracy_beginning - 0.25, max_accuracy + 0.25)
|
||||
|
||||
ax.legend(loc=legend_loc, fontsize=fontsize - 2)
|
||||
|
||||
|
||||
def plot_stats(stats: Dict[str, SimulationStatsDistribution],
|
||||
dataset_name: str,
|
||||
n_samples: int,
|
||||
filename_suffix: str,
|
||||
save_path: Optional[Path] = None,
|
||||
sample_indices: Optional[List[int]] = None) -> None:
|
||||
"""
|
||||
:param stats:
|
||||
:param dataset_name:
|
||||
:param n_samples:
|
||||
:param filename_suffix:
|
||||
:param save_path:
|
||||
:param sample_indices: Image indices used in the dataset to visualise the selected cases
|
||||
"""
|
||||
|
||||
if save_path:
|
||||
input_args = locals()
|
||||
save_path = save_path / dataset_name
|
||||
save_path.mkdir(exist_ok=True)
|
||||
save_obj(input_args, save_path / "inputs_to_plot_stats.pkl")
|
||||
|
||||
noise_rate = 100 * float(
|
||||
list(stats.values())[0].num_initial_mislabelled_ambiguous + list(stats.values())[
|
||||
0].num_initial_mislabelled_not_ambiguous) / n_samples
|
||||
# Label accuracy vs num relabels
|
||||
fontsize, legend_loc = 20, 2
|
||||
fig, ax = plot_stats_for_all_selectors(stats, ['accuracy'], [''],
|
||||
title=f'Dataset Curation - {dataset_name} (N={n_samples}, '
|
||||
f'{noise_rate:.1f}% noise)',
|
||||
x_label='Number of collected relabels on the dataset',
|
||||
y_label='Percentage of correct labels',
|
||||
legend_loc=legend_loc,
|
||||
fontsize=fontsize,
|
||||
figsize=(11, 10),
|
||||
plot_n_not_ambiguous_noise_cases=True)
|
||||
if dataset_name != "NoisyChestXray":
|
||||
plot_minimal_sampler(stats, n_samples, ax, fontsize=fontsize, legend_loc=legend_loc)
|
||||
if save_path:
|
||||
fig.savefig(save_path / f"simulation_label_accuracy_{filename_suffix}.pdf", bbox_inches='tight')
|
||||
fig.savefig(save_path / f"simulation_label_accuracy_{filename_suffix}.png", bbox_inches='tight')
|
||||
|
||||
# Label accuracy vs num relabels - To Origin
|
||||
fig, ax = plot_stats_for_all_selectors(stats, ['accuracy'], [''],
|
||||
title=f'Dataset Curation - {dataset_name} (N={n_samples})',
|
||||
x_label='Number of collected relabels on the dataset',
|
||||
y_label='Percentage of correct labels',
|
||||
legend_loc=1,
|
||||
plot_n_not_ambiguous_noise_cases=False,
|
||||
use_derivative=True)
|
||||
if save_path:
|
||||
fig.savefig(save_path / f"simulation_label_accuracy_to_origin_{filename_suffix}.png", bbox_inches='tight')
|
||||
|
||||
# Average total variation vs num relabels
|
||||
fig, ax = plot_stats_for_all_selectors(stats, ['avg_total_variation'], [''],
|
||||
title=f'Dataset Curation - {dataset_name} (N={n_samples})',
|
||||
x_label='Number of collected relabels on the dataset',
|
||||
y_label='Total Variation (Full vs Sampled Distributions)',
|
||||
fontsize=fontsize,
|
||||
include_auc=False,
|
||||
figsize=(10, 11),
|
||||
legend_loc=1)
|
||||
if save_path:
|
||||
fig.savefig(save_path / f"simulation_avg_total_variation_{filename_suffix}.png", bbox_inches='tight')
|
||||
fig.savefig(save_path / f"simulation_avg_total_variation_{filename_suffix}.pdf", bbox_inches='tight')
|
||||
|
||||
# Remaining number of noisy cases vs num relabels
|
||||
fig, ax = plot_stats_for_all_selectors(stats,
|
||||
['num_remaining_mislabelled_not_ambiguous'], [''],
|
||||
x_label='Number of collected relabels on the dataset',
|
||||
y_label='# of remaining clear noisy samples',
|
||||
fontsize=fontsize + 1,
|
||||
figsize=(10, 10),
|
||||
include_auc=False,
|
||||
linewidth=5.0,
|
||||
legend_loc="lower left")
|
||||
if save_path:
|
||||
fig.savefig(save_path / f"simulation_remaining_clear_mislabelled_{filename_suffix}.png", bbox_inches='tight')
|
||||
fig.savefig(save_path / f"simulation_remaining_clear_mislabelled_{filename_suffix}.pdf", bbox_inches='tight')
|
||||
|
||||
# Remaining number ambiguous cases vs num relabels
|
||||
fig, ax = plot_stats_for_all_selectors(stats,
|
||||
['num_remaining_mislabelled_ambiguous'], [''],
|
||||
x_label='Number of collected relabels on the dataset',
|
||||
y_label='# of remaining difficult noisy samples',
|
||||
fontsize=fontsize + 1,
|
||||
include_auc=False,
|
||||
figsize=(10, 10),
|
||||
linewidth=5.0,
|
||||
legend_loc="lower left")
|
||||
if save_path:
|
||||
fig.savefig(save_path / f"simulation_remaining_ambiguous_mislabelled_{filename_suffix}.png",
|
||||
bbox_inches='tight')
|
||||
fig.savefig(save_path / f"simulation_remaining_ambiguous_mislabelled_{filename_suffix}.pdf",
|
||||
bbox_inches='tight')
|
||||
|
||||
|
||||
def plot_roc_curve(scores: np.ndarray, labels: np.ndarray,
|
||||
type_of_cases: str = "mislabelled",
|
||||
ax: Optional[plt.Axes] = None,
|
||||
color: str = "b",
|
||||
legend: str = "AUC",
|
||||
linestyle: str = "-") -> None:
|
||||
fpr, tpr, threshold = roc_curve(labels, scores)
|
||||
roc_auc = roc_auc_score(labels, scores)
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots()
|
||||
ax.plot(fpr, tpr, color=color, label=f"{legend}: {roc_auc:.2f}", linestyle=linestyle)
|
||||
ax.set_ylabel("True positive rate")
|
||||
ax.set_xlabel("False positive rate")
|
||||
ax.set_title(f"ROC curve - {type_of_cases} detection")
|
||||
ax.legend(loc="lower right")
|
||||
ax.grid(b=True)
|
||||
|
||||
|
||||
def plot_pr_curve(scores: np.ndarray, labels: np.ndarray,
|
||||
type_of_cases: str = "mislabelled",
|
||||
ax: Optional[plt.Axes] = None,
|
||||
color: str = "b",
|
||||
legend: str = "AUC",
|
||||
linestyle: str = "-") -> None:
|
||||
precision, recall, _ = precision_recall_curve(y_true=labels, probas_pred=scores)
|
||||
pr_auc = auc(recall, precision)
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots()
|
||||
ax.plot(recall, precision, color=color, label=f"{legend}: {pr_auc:.2f}", linestyle=linestyle)
|
||||
ax.set_xlim([0.0, 1.05])
|
||||
ax.set_ylim([0.0, 1.05])
|
||||
ax.set_xlabel("Recall")
|
||||
ax.set_ylabel("Precision")
|
||||
ax.set_title(f"Precision-Recall curve - {type_of_cases} detection")
|
||||
ax.legend(loc="lower right")
|
||||
ax.grid(b=True)
|
||||
|
||||
|
||||
def plot_binary_confusion(scores: np.ndarray, labels: np.ndarray, type_of_cases: str = "mislabelled",
|
||||
ax: Optional[plt.Axes] = None) -> None:
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots()
|
||||
fpr, tpr, threshold = roc_curve(labels, scores)
|
||||
optimal_threshold = threshold[np.argmax(tpr - fpr)]
|
||||
prediction = scores > optimal_threshold
|
||||
tn, fp, fn, tp = confusion_matrix(labels, prediction).ravel()
|
||||
rates = tn / (tn + tp), fp / (tn + fp), fn / (tp + fn), tp / (tp + fn)
|
||||
cf_matrix = confusion_matrix(labels, prediction)
|
||||
group_names = ["True Neg", "False Pos", "False Neg", "True Pos"]
|
||||
group_counts = [f"{value:.0f}" for value in cf_matrix.flatten()]
|
||||
group_percentages = [f"{value:.3f}" for value in rates]
|
||||
annotations = [f"{v1}\n{v2}\n{v3}" for v1, v2, v3 in zip(group_names, group_counts, group_percentages)]
|
||||
annotations = np.asarray(annotations).reshape(2, 2)
|
||||
sns.heatmap(cf_matrix, annot=annotations, fmt="", cmap=sns.light_palette("navy", reverse=True), ax=ax,
|
||||
cbar=False)
|
||||
ax.set_xlabel("Prediction")
|
||||
ax.set_ylabel("Actual")
|
||||
ax.set_title(f"Confusion matrix - {type_of_cases} cases\nThreshold = {optimal_threshold:.2f}")
|
||||
|
||||
|
||||
def plot_stats_scores(selector_name: str, scores_mislabelled: np.ndarray, labels_mislabelled: np.ndarray,
|
||||
scores_ambiguous: Optional[np.ndarray] = None, labels_ambiguous: Optional[np.ndarray] = None,
|
||||
save_path: Optional[Path] = None) -> None:
|
||||
if scores_ambiguous is None:
|
||||
fig, ax = plt.subplots(3, 1, figsize=(5, 10))
|
||||
else:
|
||||
fig, ax = plt.subplots(3, 2, figsize=(10, 15))
|
||||
|
||||
fig.suptitle(selector_name)
|
||||
|
||||
ax = ax.ravel(order="F")
|
||||
plot_roc_curve(scores_mislabelled, labels_mislabelled, type_of_cases="mislabelled", ax=ax[0])
|
||||
plot_pr_curve(scores_mislabelled, labels_mislabelled, type_of_cases="mislabelled", ax=ax[1])
|
||||
plot_binary_confusion(scores_mislabelled, labels_mislabelled, type_of_cases="mislabelled", ax=ax[2])
|
||||
|
||||
if scores_ambiguous is not None and labels_ambiguous is not None and np.sum(labels_ambiguous) != 0:
|
||||
plot_roc_curve(scores_ambiguous, labels_ambiguous, type_of_cases="ambiguous", ax=ax[3])
|
||||
plot_pr_curve(scores_ambiguous, labels_ambiguous, type_of_cases="ambiguous", ax=ax[4])
|
||||
plot_binary_confusion(scores_ambiguous, labels_ambiguous, "ambiguous", ax[5])
|
||||
# ambiguous detection given that is mislabelled
|
||||
scores_ambiguous_given_mislabelled = scores_ambiguous[labels_mislabelled == 1]
|
||||
labels_ambiguous_given_mislabelled = labels_ambiguous[labels_mislabelled == 1]
|
||||
if roc_auc_score(labels_ambiguous_given_mislabelled, scores_ambiguous_given_mislabelled) < .5:
|
||||
scores_ambiguous_given_mislabelled *= -1
|
||||
|
||||
plot_roc_curve(scores_ambiguous_given_mislabelled,
|
||||
labels_ambiguous_given_mislabelled,
|
||||
type_of_cases="ambiguous",
|
||||
ax=ax[3],
|
||||
color="red", legend="AUC given mislabelled=True", linestyle="--")
|
||||
plot_pr_curve(scores_ambiguous_given_mislabelled,
|
||||
labels_ambiguous_given_mislabelled,
|
||||
type_of_cases="ambiguous",
|
||||
ax=ax[4],
|
||||
color="red", legend="AUC given mislabelled=True", linestyle="--")
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path / f"{selector_name}_stats_scoring.png", bbox_inches="tight")
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_relabeling_score(true_majority: np.ndarray,
|
||||
starting_scores: np.ndarray,
|
||||
current_majority: np.ndarray,
|
||||
selector_name: str) -> None:
|
||||
"""
|
||||
Plots for ranking of samples score-wise
|
||||
"""
|
||||
create_folder(MAIN_SIMULATION_DIR / "scores_histogram")
|
||||
create_folder(MAIN_SIMULATION_DIR / "noise_detection_heatmaps")
|
||||
|
||||
is_noisy = true_majority != current_majority
|
||||
total_noise_rate = np.mean(is_noisy)
|
||||
|
||||
# Compute how many noisy sampled where present in the
|
||||
# highest n_noisy samples.
|
||||
target_perc = int((1 - total_noise_rate) * 100)
|
||||
q = np.percentile(starting_scores, q=target_perc)
|
||||
noisy_cases_detected = is_noisy[starting_scores > q]
|
||||
percentage_noisy_detected = 100 * float(noisy_cases_detected.sum()) / is_noisy.sum()
|
||||
|
||||
# Plot histogram of scores differentiated by noisy or not.
|
||||
df = pd.DataFrame({"scores": starting_scores,
|
||||
"is_noisy": is_noisy})
|
||||
plt.close()
|
||||
sns.histplot(data=df, x="scores", hue="is_noisy", multiple="dodge", bins=10)
|
||||
plt.title(f"Histogram of relabeling scores {selector_name}\n"
|
||||
f"{percentage_noisy_detected:.1f}% noise cases > {target_perc}th percentile of scores")
|
||||
plt.savefig(MAIN_SIMULATION_DIR / "scores_histogram" / selector_name)
|
||||
plt.close()
|
||||
|
||||
# Plot heatmap showing where the noisy cases are located in the score ranking.
|
||||
idx = np.argsort(starting_scores)
|
||||
sorted_is_noisy = is_noisy[idx]
|
||||
plt.title(f"{selector_name}\n"
|
||||
f"Location of noisy cases by increasing scores (from left to right)\n"
|
||||
f"{percentage_noisy_detected:.1f}% noise cases > {target_perc}th percentile of scores")
|
||||
sns.heatmap(sorted_is_noisy.reshape(1, -1), yticklabels=False, vmax=1.3, cbar=False)
|
||||
plt.savefig(MAIN_SIMULATION_DIR / "noise_detection_heatmaps" / selector_name)
|
||||
plt.close()
|
|
@ -0,0 +1,115 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import datetime
|
||||
import logging
|
||||
import multiprocessing
|
||||
import numpy as np
|
||||
|
||||
from itertools import product
|
||||
from typing import Any, Dict, Tuple
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
from default_paths import MAIN_SIMULATION_DIR
|
||||
from InnerEyeDataQuality.evaluation.plot_stats import plot_stats
|
||||
from InnerEyeDataQuality.selection.data_curation_utils import get_user_specified_selectors, \
|
||||
update_trainer_for_simulation
|
||||
from InnerEyeDataQuality.selection.selectors.base import SampleSelector
|
||||
from InnerEyeDataQuality.selection.selectors.label_based import (LabelBasedDecisionRule, LabelDistributionBasedSampler,
|
||||
PosteriorBasedSelector)
|
||||
from InnerEyeDataQuality.selection.selectors.bald import BaldSelector
|
||||
from InnerEyeDataQuality.selection.selectors.random_selector import RandomSelector
|
||||
from InnerEyeDataQuality.selection.simulation import DataCurationSimulator
|
||||
from InnerEyeDataQuality.selection.simulation_statistics import SimulationStats, SimulationStatsDistribution
|
||||
from InnerEyeDataQuality.utils.dataset_utils import load_dataset_and_initial_labels_for_simulation
|
||||
from InnerEyeDataQuality.utils.generic import create_folder, get_data_selection_parser, get_logger, set_seed
|
||||
|
||||
|
||||
EXP_OUTPUT_DIR = MAIN_SIMULATION_DIR / datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
|
||||
|
||||
|
||||
def main(args: argparse.Namespace) -> None:
|
||||
dataset, initial_labels = load_dataset_and_initial_labels_for_simulation(args.config[0], args.on_val_set)
|
||||
n_classes = dataset.num_classes
|
||||
n_samples = dataset.num_samples
|
||||
true_distribution = dataset.label_distribution
|
||||
|
||||
user_specified_selectors = get_user_specified_selectors(list_configs=args.config,
|
||||
dataset=dataset,
|
||||
output_path=MAIN_SIMULATION_DIR,
|
||||
plot_embeddings=args.plot_embeddings)
|
||||
|
||||
# Data selection simulations for annotation
|
||||
default_selector = {
|
||||
'Random': RandomSelector(n_samples, n_classes, name='Random'),
|
||||
'Oracle': PosteriorBasedSelector(true_distribution.distribution, n_samples,
|
||||
num_classes=n_classes,
|
||||
name='Oracle',
|
||||
allow_repeat_samples=True,
|
||||
decision_rule=LabelBasedDecisionRule.INV)}
|
||||
sample_selectors = {**user_specified_selectors, **default_selector}
|
||||
|
||||
# Benchmark 2
|
||||
# Determine the number of simulation iterations based on the noise rate
|
||||
expected_noise_rate = np.mean(np.argmax(true_distribution.distribution, -1) != dataset.targets[:n_samples])
|
||||
relabel_budget = int(min(n_samples * expected_noise_rate, n_samples) * 0.35)
|
||||
if dataset.name == "NoisyChestXray":
|
||||
relabel_budget = min(int(n_samples * expected_noise_rate * 2.5), n_samples)
|
||||
else:
|
||||
relabel_budget = min(int(n_samples * expected_noise_rate * 3.0), n_samples)
|
||||
|
||||
logging.info(f"Expected noise rate {expected_noise_rate} - Allocated relabelling budget {relabel_budget}")
|
||||
|
||||
# Setup the simulation function.
|
||||
def _run_simulation_for_selector(name: str,
|
||||
seed: int,
|
||||
sample_selector: SampleSelector) -> Tuple[str, SimulationStats]:
|
||||
if isinstance(sample_selector, (LabelDistributionBasedSampler, BaldSelector)):
|
||||
update_trainer_for_simulation(sample_selector, seed=seed)
|
||||
simulator = DataCurationSimulator(initial_labels=copy.deepcopy(initial_labels),
|
||||
label_distribution=copy.deepcopy(true_distribution),
|
||||
relabel_budget=relabel_budget,
|
||||
name=name,
|
||||
seed=seed,
|
||||
sample_selector=copy.deepcopy(sample_selector))
|
||||
simulator.run_simulation()
|
||||
if sample_selector.output_directory is not None:
|
||||
simulator.save_simulator_results(MAIN_SIMULATION_DIR / sample_selector.output_directory / f"seed_{seed}")
|
||||
return name, simulator.global_stats
|
||||
|
||||
# Run the simulation over multiple seeds and selectors
|
||||
simulation_iter = product(sample_selectors.items(), args.seeds)
|
||||
if args.debug:
|
||||
parallel_output = [_run_simulation_for_selector(_name, _seed, _sel) for (_name, _sel), _seed in simulation_iter]
|
||||
else:
|
||||
num_jobs = min(len(sample_selectors) * len(args.seeds), multiprocessing.cpu_count())
|
||||
parallel = Parallel(n_jobs=num_jobs)
|
||||
parallel_output = parallel(delayed(_run_simulation_for_selector)(_name, _seed, _selector)
|
||||
for (_name, _selector), _seed in simulation_iter)
|
||||
|
||||
# Aggregate parallel output arrays
|
||||
global_stats: Dict[str, Any] = {name: list() for name in set([name for name, _ in parallel_output])}
|
||||
[global_stats[name].append(stats) for name, stats in parallel_output]
|
||||
|
||||
# Analyse simulation stats
|
||||
stats_dist = {name: SimulationStatsDistribution(stats) for name, stats in global_stats.items()}
|
||||
plot_filename_suffix = "val" if args.on_val_set else "train"
|
||||
plot_stats(stats_dist,
|
||||
dataset_name=dataset.name,
|
||||
n_samples=n_samples,
|
||||
save_path=EXP_OUTPUT_DIR,
|
||||
filename_suffix=plot_filename_suffix)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_folder(EXP_OUTPUT_DIR)
|
||||
get_logger(EXP_OUTPUT_DIR / 'dataread.log')
|
||||
parser = get_data_selection_parser()
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
set_seed(seed=12345)
|
||||
main(args)
|
|
@ -0,0 +1,146 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from default_paths import MAIN_SIMULATION_DIR, MODEL_SELECTION_BENCHMARK_DIR, FIGURE_DIR
|
||||
from InnerEyeDataQuality.datasets.cifar10_utils import get_cifar10_label_names
|
||||
from InnerEyeDataQuality.deep_learning.model_inference import inference_ensemble
|
||||
from InnerEyeDataQuality.deep_learning.utils import create_logger, load_model_config, load_selector_config
|
||||
from InnerEyeDataQuality.selection.selectors.label_based import cross_entropy
|
||||
from InnerEyeDataQuality.selection.simulation import DataCurationSimulator
|
||||
from InnerEyeDataQuality.selection.simulation_statistics import get_ambiguous_sample_ids
|
||||
from InnerEyeDataQuality.utils.dataset_utils import get_datasets
|
||||
from InnerEyeDataQuality.utils.generic import create_folder
|
||||
from InnerEyeDataQuality.utils.plot import plot_confusion_matrix
|
||||
|
||||
|
||||
def get_rank(array: Union[np.ndarray, List]) -> int:
|
||||
"""
|
||||
Returns the ranking of an array where the highest value has a rank of 1 and
|
||||
the lowest value has the highest rank.
|
||||
"""
|
||||
return len(array) - np.argsort(array).argsort()
|
||||
|
||||
|
||||
def main(args: argparse.Namespace) -> None:
|
||||
|
||||
# Parameters
|
||||
number_of_runs = 1
|
||||
evaluate_on_ambiguous_samples = False
|
||||
|
||||
# Create the evaluation dataset - Make sure that it's the same dataset for all configs
|
||||
assert isinstance(args.config, list)
|
||||
_, dataset = get_datasets(load_model_config(args.config[0]),
|
||||
use_augmentation=False,
|
||||
use_noisy_labels_for_validation=True,
|
||||
use_fixed_labels=False if number_of_runs > 1 else True)
|
||||
|
||||
# Choose a subset of the dataset
|
||||
if evaluate_on_ambiguous_samples:
|
||||
ind = get_ambiguous_sample_ids(dataset.label_counts, threshold=0.10) # type: ignore
|
||||
else:
|
||||
ind = range(len(dataset)) # type: ignore
|
||||
|
||||
# If specified load curated dataset labels
|
||||
curated_target_labels = dict()
|
||||
for cfg_path in args.curated_label_config if args.curated_label_config else list():
|
||||
cfg = load_selector_config(cfg_path)
|
||||
search_dir = MAIN_SIMULATION_DIR / cfg.selector.output_directory
|
||||
targets_list = list()
|
||||
for _f in search_dir.glob('**/*.hdf'):
|
||||
_label_counts = DataCurationSimulator.load_simulator_results(_f)
|
||||
_targets = np.argmax(_label_counts, axis=1)
|
||||
targets_list.append(_targets)
|
||||
curated_target_labels[str(Path(cfg_path).stem)] = targets_list
|
||||
|
||||
# Define class labels for noisy, clean and curated datasets
|
||||
df_rows_list = []
|
||||
metric_names = ["accuracy", "top_n_accuracy", "cross_entropy", "accuracy_per_class"]
|
||||
|
||||
# Run the same experiment multiple time
|
||||
for _run_id in range(number_of_runs):
|
||||
target_labels = {"clean": dataset.clean_targets[ind], # type: ignore
|
||||
"noisy": np.array([dataset.__getitem__(_i)[2] for _i in ind]),
|
||||
**{_n: _l[_run_id] for _n, _l in curated_target_labels.items()}}
|
||||
|
||||
# Loops over different models
|
||||
for config_id, config in enumerate([load_model_config(cfg) for cfg in args.config]):
|
||||
posteriors = inference_ensemble(dataset, config)[1][ind]
|
||||
|
||||
# Collect metrics
|
||||
for _label_name, _label in target_labels.items():
|
||||
df_row = {"model": Path(args.config[config_id]).stem, "run_id": _run_id,
|
||||
"dataset": _label_name, "count": _label.size}
|
||||
for _metric_name in metric_names:
|
||||
_val = benchmark_metrics(posteriors, observed_labels=_label, metric_name=_metric_name,
|
||||
true_labels=target_labels["clean"])
|
||||
df_row.update({_metric_name: _val}) # type: ignore
|
||||
df_rows_list.append(df_row)
|
||||
|
||||
df = pd.DataFrame(df_rows_list)
|
||||
df = df.sort_values(by=["dataset", "model"], axis=0)
|
||||
logging.info(f"\n{df.to_string()}")
|
||||
|
||||
# Aggregate multiple runs
|
||||
group_cols = ['model', 'dataset']
|
||||
df_grouped = df.groupby(group_cols, as_index=False)['accuracy', 'count', 'cross_entropy'].agg([np.mean, np.std])
|
||||
logging.info(f"\n{df_grouped.to_string()}")
|
||||
|
||||
# Plot the observed confusion matrix
|
||||
plot_confusion_matrix(target_labels["clean"], target_labels["noisy"],
|
||||
get_cifar10_label_names(), save_path=FIGURE_DIR)
|
||||
|
||||
|
||||
def benchmark_metrics(posteriors: np.ndarray,
|
||||
observed_labels: np.ndarray,
|
||||
metric_name: str,
|
||||
true_labels: np.ndarray) -> Union[float, List[float]]:
|
||||
"""
|
||||
Defines metrics to be used in model comparison.
|
||||
"""
|
||||
predictions = np.argmax(posteriors, axis=1)
|
||||
|
||||
# Accuracy averaged across all classes
|
||||
if metric_name == "accuracy":
|
||||
return np.mean(predictions == observed_labels) * 100.0
|
||||
# Cross-entropy loss across all samples
|
||||
elif metric_name == "top_n_accuracy":
|
||||
N = 2
|
||||
sorted_class_predictions = np.argsort(posteriors, axis=1)[:, ::-1]
|
||||
correct = int(0)
|
||||
for _i in range(observed_labels.size):
|
||||
correct += np.any(sorted_class_predictions[_i, :N] == observed_labels[_i])
|
||||
return correct * 100.0 / observed_labels.size
|
||||
elif metric_name == "cross_entropy":
|
||||
return np.mean(cross_entropy(posteriors, np.eye(10)[observed_labels]))
|
||||
# Average accuracy per class - samples are groupped based on their true class label
|
||||
elif metric_name == "accuracy_per_class":
|
||||
vals = list()
|
||||
for _class in np.unique(true_labels, return_counts=False):
|
||||
mask = true_labels == _class
|
||||
val = np.mean(predictions[mask] == observed_labels[mask]) * 100.0
|
||||
vals.append(np.around(val, decimals=3))
|
||||
return vals
|
||||
else:
|
||||
raise ValueError("Unknown metric")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Execute benchmark 3')
|
||||
parser.add_argument('--config', dest='config', type=str, required=True, nargs='+',
|
||||
help='Path to config file(s) characterising trained CNN model/s')
|
||||
parser.add_argument('--curated-label-config', dest='curated_label_config', type=str, required=False, nargs='+',
|
||||
help='Path to config file(s) corresponding to curated labels in adjudication simulation')
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
create_folder(MODEL_SELECTION_BENCHMARK_DIR)
|
||||
create_logger(MODEL_SELECTION_BENCHMARK_DIR)
|
||||
main(args)
|
|
@ -0,0 +1,5 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше
Загрузка…
Ссылка в новой задаче