added ability to score test images to disk for dutch F3 (#232)

* added ability to score test images to disk for dutch F3

* gitpython package fix from staging
maxkazmsft committed via GitHub, 2020-03-19 13:49:09 -04:00
Parent: b75e6476c9
Commit: 4484155801
4 changed files: 124 additions and 31 deletions
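In essence, the new scoring-to-disk feature comes down to two small helpers added to the testing script, normalize() and mask_to_disk(), plus per-section calls that dump each prediction and ground-truth mask as a colormapped PNG. The standalone sketch below reproduces that flow for orientation; the random mask, the outputs_test directory name, and the __main__ harness are illustrative stand-ins only (in the PR, images are written into a run directory built with generate_path()):

import os

import numpy as np
from matplotlib import cm
from PIL import Image


def normalize(array):
    # rescale the mask into [0, 1] so the colormap spans its full range (mirrors the helper added below)
    mn = array.min()
    return (array - mn) / (array.max() - mn)


def mask_to_disk(mask, fname):
    # map the normalized mask through a colormap and save it as an 8-bit RGBA PNG (mirrors the helper added below)
    Image.fromarray(cm.gist_earth(normalize(mask), bytes=True)).save(fname)


if __name__ == "__main__":
    output_dir = "outputs_test"  # stand-in for the run-specific directory the script builds
    os.makedirs(output_dir, exist_ok=True)
    fake_pred = np.random.randint(0, 6, size=(100, 200))  # hypothetical 6-class mask, not real F3 data
    mask_to_disk(fake_pred, os.path.join(output_dir, "0_pred.png"))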

View file

@@ -6,10 +6,7 @@ import torchvision
 import logging
 import logging.config
 
-try:
-    from tensorboardX import SummaryWriter
-except ImportError:
-    raise RuntimeError("No tensorboardX package is found. Please install with the command: \npip install tensorboardX")
+from tensorboardX import SummaryWriter
 
 
 def create_summary_writer(log_dir):
@@ -52,16 +49,22 @@ _DEFAULT_METRICS = {"accuracy": "Avg accuracy :", "nll": "Avg loss :"}
 def log_metrics(summary_writer, train_engine, log_interval, engine, metrics_dict=_DEFAULT_METRICS):
     metrics = engine.state.metrics
     for m in metrics_dict:
-        summary_writer.add_scalar(metrics_dict[m], metrics[m], getattr(train_engine.state, log_interval))
+        summary_writer.add_scalar(
+            metrics_dict[m], metrics[m], getattr(train_engine.state, log_interval)
+        )
 
 
-def create_image_writer(summary_writer, label, output_variable, normalize=False, transform_func=lambda x: x):
+def create_image_writer(
+    summary_writer, label, output_variable, normalize=False, transform_func=lambda x: x
+):
     logger = logging.getLogger(__name__)
 
     def write_to(engine):
         try:
             data_tensor = transform_func(engine.state.output[output_variable])
-            image_grid = torchvision.utils.make_grid(data_tensor, normalize=normalize, scale_each=True)
+            image_grid = torchvision.utils.make_grid(
+                data_tensor, normalize=normalize, scale_each=True
+            )
             summary_writer.add_image(label, image_grid, engine.state.epoch)
         except KeyError:
             logger.warning("Predictions and or ground truth labels not available to report")

View file

@@ -26,7 +26,7 @@ dependencies:
 - toolz==0.10.0
 - tabulate==0.8.2
 - Jinja2==2.10.3
-- gitpython==3.0.5
+- gitpython==3.0.6
 - tensorboard==2.0.1
 - tensorboardx==1.9
 - invoke==1.3.0

View file

@@ -10,8 +10,6 @@
 """
 Modified version of the Alaudah testing script
 Runs only on single GPU
-
-Estimated time to run on single V100: 5 hours
 """
 
 import itertools
@@ -25,9 +23,16 @@ import fire
 import numpy as np
 import torch
 import torch.nn.functional as F
+from PIL import Image
 from albumentations import Compose, Normalize, PadIfNeeded, Resize
 from cv_lib.utils import load_log_configuration
 from cv_lib.segmentation import models
+from cv_lib.segmentation.dutchf3.utils import (
+    current_datetime,
+    generate_path,
+    git_branch,
+    git_hash,
+)
 from deepseismic_interpretation.dutchf3.data import (
     add_patch_depth_channels,
     get_seismic_labels,
@@ -39,6 +44,8 @@ from toolz import compose, curry, itertoolz, pipe
 from torch.utils import data
 
 from toolz import take
+from matplotlib import cm
+
 
 _CLASS_NAMES = [
     "upper_ns",
@@ -57,9 +64,9 @@ class runningScore(object):
 
     def _fast_hist(self, label_true, label_pred, n_class):
         mask = (label_true >= 0) & (label_true < n_class)
-        hist = np.bincount(n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2,).reshape(
-            n_class, n_class
-        )
+        hist = np.bincount(
+            n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2,
+        ).reshape(n_class, n_class)
         return hist
 
     def update(self, label_trues, label_preds):
@@ -99,6 +106,21 @@ class runningScore(object):
         self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
 
 
+def normalize(array):
+    """
+    Normalizes a segmentation mask array to be in [0,1] range
+    """
+    min = array.min()
+    return (array - min) / (array.max() - min)
+
+
+def mask_to_disk(mask, fname):
+    """
+    write segmentation mask to disk using a particular colormap
+    """
+    Image.fromarray(cm.gist_earth(normalize(mask), bytes=True)).save(fname)
+
+
 def _transform_CHW_to_HWC(numpy_array):
     return np.moveaxis(numpy_array, 0, -1)
@@ -180,7 +202,9 @@ def _compose_processing_pipeline(depth, aug=None):
 def _generate_batches(h, w, ps, patch_size, stride, batch_size=64):
-    hdc_wdx_generator = itertools.product(range(0, h - patch_size + ps, stride), range(0, w - patch_size + ps, stride),)
+    hdc_wdx_generator = itertools.product(
+        range(0, h - patch_size + ps, stride), range(0, w - patch_size + ps, stride),
+    )
 
     for batch_indexes in itertoolz.partition_all(batch_size, hdc_wdx_generator):
         yield batch_indexes
@@ -191,7 +215,9 @@ def _output_processing_pipeline(config, output):
     _, _, h, w = output.shape
     if config.TEST.POST_PROCESSING.SIZE != h or config.TEST.POST_PROCESSING.SIZE != w:
         output = F.interpolate(
-            output, size=(config.TEST.POST_PROCESSING.SIZE, config.TEST.POST_PROCESSING.SIZE,), mode="bilinear",
+            output,
+            size=(config.TEST.POST_PROCESSING.SIZE, config.TEST.POST_PROCESSING.SIZE,),
+            mode="bilinear",
         )
 
     if config.TEST.POST_PROCESSING.CROP_PIXELS > 0:
@@ -206,7 +232,15 @@ def _output_processing_pipeline(config, output):
 def _patch_label_2d(
-    model, img, pre_processing, output_processing, patch_size, stride, batch_size, device, num_classes,
+    model,
+    img,
+    pre_processing,
+    output_processing,
+    patch_size,
+    stride,
+    batch_size,
+    device,
+    num_classes,
 ):
     """Processes a whole section
     """
@@ -221,14 +255,19 @@ def _patch_label_2d(
     # generate output:
     for batch_indexes in _generate_batches(h, w, ps, patch_size, stride, batch_size=batch_size):
         batch = torch.stack(
-            [pipe(img_p, _extract_patch(hdx, wdx, ps, patch_size), pre_processing,) for hdx, wdx in batch_indexes],
+            [
+                pipe(img_p, _extract_patch(hdx, wdx, ps, patch_size), pre_processing,)
+                for hdx, wdx in batch_indexes
+            ],
             dim=0,
         )
         model_output = model(batch.to(device))
         for (hdx, wdx), output in zip(batch_indexes, model_output.detach().cpu()):
             output = output_processing(output)
-            output_p[:, :, hdx + ps : hdx + ps + patch_size, wdx + ps : wdx + ps + patch_size,] += output
+            output_p[
+                :, :, hdx + ps : hdx + ps + patch_size, wdx + ps : wdx + ps + patch_size,
+            ] += output
 
     # crop the output_p in the middle
     output = output_p[:, :, ps:-ps, ps:-ps]
@@ -253,12 +292,22 @@ def to_image(label_mask, n_classes=6):
 def _evaluate_split(
-    split, section_aug, model, pre_processing, output_processing, device, running_metrics_overall, config, debug=False
+    split,
+    section_aug,
+    model,
+    pre_processing,
+    output_processing,
+    device,
+    running_metrics_overall,
+    config,
+    debug=False,
 ):
     logger = logging.getLogger(__name__)
 
     TestSectionLoader = get_test_loader(config)
-    test_set = TestSectionLoader(config.DATASET.ROOT, split=split, is_transform=True, augmentations=section_aug,)
+    test_set = TestSectionLoader(
+        config.DATASET.ROOT, split=split, is_transform=True, augmentations=section_aug,
+    )
 
     n_classes = test_set.n_classes
@@ -268,6 +317,21 @@ def _evaluate_split(
         logger.info("Running in Debug/Test mode")
         test_loader = take(1, test_loader)
 
+    try:
+        output_dir = generate_path(
+            config.OUTPUT_DIR + "_test",
+            git_branch(),
+            git_hash(),
+            config.MODEL.NAME,
+            current_datetime(),
+        )
+    except TypeError:
+        output_dir = generate_path(
+            config.OUTPUT_DIR + "_test",
+            config.MODEL.NAME,
+            current_datetime(),
+        )
+
     running_metrics_split = runningScore(n_classes)
 
     # testing mode:
@@ -295,6 +359,10 @@ def _evaluate_split(
             running_metrics_split.update(gt, pred)
             running_metrics_overall.update(gt, pred)
 
+            # dump images to disk for review
+            mask_to_disk(pred.squeeze(), os.path.join(output_dir, f"{i}_pred.png"))
+            mask_to_disk(gt.squeeze(), os.path.join(output_dir, f"{i}_gt.png"))
+
     # get scores
     score, class_iou = running_metrics_split.get_scores()
@@ -350,12 +418,16 @@ def test(*options, cfg=None, debug=False):
     running_metrics_overall = runningScore(n_classes)
 
     # Augmentation
-    section_aug = Compose([Normalize(mean=(config.TRAIN.MEAN,), std=(config.TRAIN.STD,), max_pixel_value=1,)])
+    section_aug = Compose(
+        [Normalize(mean=(config.TRAIN.MEAN,), std=(config.TRAIN.STD,), max_pixel_value=1,)]
+    )
 
     patch_aug = Compose(
         [
             Resize(
-                config.TRAIN.AUGMENTATIONS.RESIZE.HEIGHT, config.TRAIN.AUGMENTATIONS.RESIZE.WIDTH, always_apply=True,
+                config.TRAIN.AUGMENTATIONS.RESIZE.HEIGHT,
+                config.TRAIN.AUGMENTATIONS.RESIZE.WIDTH,
+                always_apply=True,
             ),
             PadIfNeeded(
                 min_height=config.TRAIN.AUGMENTATIONS.PAD.HEIGHT,
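One detail worth noting from the hunks above: before any masks are written, the testing script now builds a run-specific output directory, first trying to embed the git branch and hash in the path and falling back to a shorter path on TypeError (presumably raised by generate_path() when git metadata is unavailable, e.g. outside a repository). A rough sketch of that pattern, with a hypothetical make_output_dir() standing in for cv_lib's generate_path() and made-up branch/hash/model values:

import os
from datetime import datetime


def make_output_dir(*parts):
    # hypothetical stand-in for generate_path(): join the parts and create the directory;
    # os.path.join raises TypeError if any part is None, which the fallback below relies on
    path = os.path.join(*parts)
    os.makedirs(path, exist_ok=True)
    return path


timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")  # illustrative timestamp format
branch, commit = "staging", None  # commit=None simulates missing git metadata

try:
    output_dir = make_output_dir("output_test", branch, commit, "hrnet", timestamp)
except TypeError:
    output_dir = make_output_dir("output_test", "hrnet", timestamp)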

View file

@@ -111,7 +111,9 @@ def run(*options, cfg=None, debug=False):
         [
             Normalize(mean=(config.TRAIN.MEAN,), std=(config.TRAIN.STD,), max_pixel_value=1),
             Resize(
-                config.TRAIN.AUGMENTATIONS.RESIZE.HEIGHT, config.TRAIN.AUGMENTATIONS.RESIZE.WIDTH, always_apply=True,
+                config.TRAIN.AUGMENTATIONS.RESIZE.HEIGHT,
+                config.TRAIN.AUGMENTATIONS.RESIZE.WIDTH,
+                always_apply=True,
             ),
             PadIfNeeded(
                 min_height=config.TRAIN.AUGMENTATIONS.PAD.HEIGHT,
@@ -151,9 +153,14 @@ def run(*options, cfg=None, debug=False):
     n_classes = train_set.n_classes
 
     train_loader = data.DataLoader(
-        train_set, batch_size=config.TRAIN.BATCH_SIZE_PER_GPU, num_workers=config.WORKERS, shuffle=True,
+        train_set,
+        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
+        num_workers=config.WORKERS,
+        shuffle=True,
+    )
+    val_loader = data.DataLoader(
+        val_set, batch_size=config.VALIDATION.BATCH_SIZE_PER_GPU, num_workers=config.WORKERS,
     )
-    val_loader = data.DataLoader(val_set, batch_size=config.VALIDATION.BATCH_SIZE_PER_GPU, num_workers=config.WORKERS,)
 
     model = getattr(models, config.MODEL.NAME).get_seg_model(config)
@@ -170,14 +177,18 @@ def run(*options, cfg=None, debug=False):
     )
 
     try:
-        output_dir = generate_path(config.OUTPUT_DIR, git_branch(), git_hash(), config.MODEL.NAME, current_datetime(),)
+        output_dir = generate_path(
+            config.OUTPUT_DIR, git_branch(), git_hash(), config.MODEL.NAME, current_datetime(),
+        )
     except TypeError:
         output_dir = generate_path(config.OUTPUT_DIR, config.MODEL.NAME, current_datetime(),)
 
     summary_writer = create_summary_writer(log_dir=path.join(output_dir, config.LOG_DIR))
 
     snapshot_duration = scheduler_step * len(train_loader)
-    scheduler = CosineAnnealingScheduler(optimizer, "lr", config.TRAIN.MAX_LR, config.TRAIN.MIN_LR, snapshot_duration)
+    scheduler = CosineAnnealingScheduler(
+        optimizer, "lr", config.TRAIN.MAX_LR, config.TRAIN.MIN_LR, snapshot_duration
+    )
 
     # weights are inversely proportional to the frequency of the classes in the
     # training set
@@ -190,7 +201,8 @@ def run(*options, cfg=None, debug=False):
     trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
 
     trainer.add_event_handler(
-        Events.ITERATION_COMPLETED, logging_handlers.log_training_output(log_interval=config.PRINT_FREQ),
+        Events.ITERATION_COMPLETED,
+        logging_handlers.log_training_output(log_interval=config.PRINT_FREQ),
     )
     trainer.add_event_handler(Events.EPOCH_STARTED, logging_handlers.log_lr(optimizer))
     trainer.add_event_handler(
@@ -208,7 +220,9 @@ def run(*options, cfg=None, debug=False):
         prepare_batch,
         metrics={
             "nll": Loss(criterion, output_transform=_select_pred_and_mask),
-            "pixacc": pixelwise_accuracy(n_classes, output_transform=_select_pred_and_mask, device=device),
+            "pixacc": pixelwise_accuracy(
+                n_classes, output_transform=_select_pred_and_mask, device=device
+            ),
             "cacc": class_accuracy(n_classes, output_transform=_select_pred_and_mask),
             "mca": mean_class_accuracy(n_classes, output_transform=_select_pred_and_mask),
             "ciou": class_iou(n_classes, output_transform=_select_pred_and_mask),
@@ -267,11 +281,15 @@ def run(*options, cfg=None, debug=False):
     )
     evaluator.add_event_handler(
         Events.EPOCH_COMPLETED,
-        create_image_writer(summary_writer, "Validation/Mask", "mask", transform_func=transform_func),
+        create_image_writer(
+            summary_writer, "Validation/Mask", "mask", transform_func=transform_func
+        ),
     )
     evaluator.add_event_handler(
         Events.EPOCH_COMPLETED,
-        create_image_writer(summary_writer, "Validation/Pred", "y_pred", transform_func=transform_pred),
+        create_image_writer(
+            summary_writer, "Validation/Pred", "y_pred", transform_func=transform_pred
+        ),
     )
 
     def snapshot_function():