added ability to score test images to disk for dutch F3 (#232)
* added ability to score test images to disk for dutch F3 * gitpython package fix from staging
This commit is contained in:
Родитель
b75e6476c9
Коммит
4484155801
|
@ -6,10 +6,7 @@ import torchvision
|
||||||
import logging
|
import logging
|
||||||
import logging.config
|
import logging.config
|
||||||
|
|
||||||
try:
|
from tensorboardX import SummaryWriter
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
except ImportError:
|
|
||||||
raise RuntimeError("No tensorboardX package is found. Please install with the command: \npip install tensorboardX")
|
|
||||||
|
|
||||||
|
|
||||||
def create_summary_writer(log_dir):
|
def create_summary_writer(log_dir):
|
||||||
|
@ -52,16 +49,22 @@ _DEFAULT_METRICS = {"accuracy": "Avg accuracy :", "nll": "Avg loss :"}
|
||||||
def log_metrics(summary_writer, train_engine, log_interval, engine, metrics_dict=_DEFAULT_METRICS):
|
def log_metrics(summary_writer, train_engine, log_interval, engine, metrics_dict=_DEFAULT_METRICS):
|
||||||
metrics = engine.state.metrics
|
metrics = engine.state.metrics
|
||||||
for m in metrics_dict:
|
for m in metrics_dict:
|
||||||
summary_writer.add_scalar(metrics_dict[m], metrics[m], getattr(train_engine.state, log_interval))
|
summary_writer.add_scalar(
|
||||||
|
metrics_dict[m], metrics[m], getattr(train_engine.state, log_interval)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_image_writer(summary_writer, label, output_variable, normalize=False, transform_func=lambda x: x):
|
def create_image_writer(
|
||||||
|
summary_writer, label, output_variable, normalize=False, transform_func=lambda x: x
|
||||||
|
):
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def write_to(engine):
|
def write_to(engine):
|
||||||
try:
|
try:
|
||||||
data_tensor = transform_func(engine.state.output[output_variable])
|
data_tensor = transform_func(engine.state.output[output_variable])
|
||||||
image_grid = torchvision.utils.make_grid(data_tensor, normalize=normalize, scale_each=True)
|
image_grid = torchvision.utils.make_grid(
|
||||||
|
data_tensor, normalize=normalize, scale_each=True
|
||||||
|
)
|
||||||
summary_writer.add_image(label, image_grid, engine.state.epoch)
|
summary_writer.add_image(label, image_grid, engine.state.epoch)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.warning("Predictions and or ground truth labels not available to report")
|
logger.warning("Predictions and or ground truth labels not available to report")
|
||||||
|
|
|
@ -26,7 +26,7 @@ dependencies:
|
||||||
- toolz==0.10.0
|
- toolz==0.10.0
|
||||||
- tabulate==0.8.2
|
- tabulate==0.8.2
|
||||||
- Jinja2==2.10.3
|
- Jinja2==2.10.3
|
||||||
- gitpython==3.0.5
|
- gitpython==3.0.6
|
||||||
- tensorboard==2.0.1
|
- tensorboard==2.0.1
|
||||||
- tensorboardx==1.9
|
- tensorboardx==1.9
|
||||||
- invoke==1.3.0
|
- invoke==1.3.0
|
||||||
|
|
|
@ -10,8 +10,6 @@
|
||||||
"""
|
"""
|
||||||
Modified version of the Alaudah testing script
|
Modified version of the Alaudah testing script
|
||||||
Runs only on single GPU
|
Runs only on single GPU
|
||||||
|
|
||||||
Estimated time to run on single V100: 5 hours
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
|
@ -25,9 +23,16 @@ import fire
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from PIL import Image
|
||||||
from albumentations import Compose, Normalize, PadIfNeeded, Resize
|
from albumentations import Compose, Normalize, PadIfNeeded, Resize
|
||||||
from cv_lib.utils import load_log_configuration
|
from cv_lib.utils import load_log_configuration
|
||||||
from cv_lib.segmentation import models
|
from cv_lib.segmentation import models
|
||||||
|
from cv_lib.segmentation.dutchf3.utils import (
|
||||||
|
current_datetime,
|
||||||
|
generate_path,
|
||||||
|
git_branch,
|
||||||
|
git_hash,
|
||||||
|
)
|
||||||
from deepseismic_interpretation.dutchf3.data import (
|
from deepseismic_interpretation.dutchf3.data import (
|
||||||
add_patch_depth_channels,
|
add_patch_depth_channels,
|
||||||
get_seismic_labels,
|
get_seismic_labels,
|
||||||
|
@ -39,6 +44,8 @@ from toolz import compose, curry, itertoolz, pipe
|
||||||
from torch.utils import data
|
from torch.utils import data
|
||||||
from toolz import take
|
from toolz import take
|
||||||
|
|
||||||
|
from matplotlib import cm
|
||||||
|
|
||||||
|
|
||||||
_CLASS_NAMES = [
|
_CLASS_NAMES = [
|
||||||
"upper_ns",
|
"upper_ns",
|
||||||
|
@ -57,9 +64,9 @@ class runningScore(object):
|
||||||
|
|
||||||
def _fast_hist(self, label_true, label_pred, n_class):
|
def _fast_hist(self, label_true, label_pred, n_class):
|
||||||
mask = (label_true >= 0) & (label_true < n_class)
|
mask = (label_true >= 0) & (label_true < n_class)
|
||||||
hist = np.bincount(n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2,).reshape(
|
hist = np.bincount(
|
||||||
n_class, n_class
|
n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2,
|
||||||
)
|
).reshape(n_class, n_class)
|
||||||
return hist
|
return hist
|
||||||
|
|
||||||
def update(self, label_trues, label_preds):
|
def update(self, label_trues, label_preds):
|
||||||
|
@ -99,6 +106,21 @@ class runningScore(object):
|
||||||
self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
|
self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(array):
|
||||||
|
"""
|
||||||
|
Normalizes a segmentation mask array to be in [0,1] range
|
||||||
|
"""
|
||||||
|
min = array.min()
|
||||||
|
return (array - min) / (array.max() - min)
|
||||||
|
|
||||||
|
|
||||||
|
def mask_to_disk(mask, fname):
|
||||||
|
"""
|
||||||
|
write segmentation mask to disk using a particular colormap
|
||||||
|
"""
|
||||||
|
Image.fromarray(cm.gist_earth(normalize(mask), bytes=True)).save(fname)
|
||||||
|
|
||||||
|
|
||||||
def _transform_CHW_to_HWC(numpy_array):
|
def _transform_CHW_to_HWC(numpy_array):
|
||||||
return np.moveaxis(numpy_array, 0, -1)
|
return np.moveaxis(numpy_array, 0, -1)
|
||||||
|
|
||||||
|
@ -180,7 +202,9 @@ def _compose_processing_pipeline(depth, aug=None):
|
||||||
|
|
||||||
|
|
||||||
def _generate_batches(h, w, ps, patch_size, stride, batch_size=64):
|
def _generate_batches(h, w, ps, patch_size, stride, batch_size=64):
|
||||||
hdc_wdx_generator = itertools.product(range(0, h - patch_size + ps, stride), range(0, w - patch_size + ps, stride),)
|
hdc_wdx_generator = itertools.product(
|
||||||
|
range(0, h - patch_size + ps, stride), range(0, w - patch_size + ps, stride),
|
||||||
|
)
|
||||||
for batch_indexes in itertoolz.partition_all(batch_size, hdc_wdx_generator):
|
for batch_indexes in itertoolz.partition_all(batch_size, hdc_wdx_generator):
|
||||||
yield batch_indexes
|
yield batch_indexes
|
||||||
|
|
||||||
|
@ -191,7 +215,9 @@ def _output_processing_pipeline(config, output):
|
||||||
_, _, h, w = output.shape
|
_, _, h, w = output.shape
|
||||||
if config.TEST.POST_PROCESSING.SIZE != h or config.TEST.POST_PROCESSING.SIZE != w:
|
if config.TEST.POST_PROCESSING.SIZE != h or config.TEST.POST_PROCESSING.SIZE != w:
|
||||||
output = F.interpolate(
|
output = F.interpolate(
|
||||||
output, size=(config.TEST.POST_PROCESSING.SIZE, config.TEST.POST_PROCESSING.SIZE,), mode="bilinear",
|
output,
|
||||||
|
size=(config.TEST.POST_PROCESSING.SIZE, config.TEST.POST_PROCESSING.SIZE,),
|
||||||
|
mode="bilinear",
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.TEST.POST_PROCESSING.CROP_PIXELS > 0:
|
if config.TEST.POST_PROCESSING.CROP_PIXELS > 0:
|
||||||
|
@ -206,7 +232,15 @@ def _output_processing_pipeline(config, output):
|
||||||
|
|
||||||
|
|
||||||
def _patch_label_2d(
|
def _patch_label_2d(
|
||||||
model, img, pre_processing, output_processing, patch_size, stride, batch_size, device, num_classes,
|
model,
|
||||||
|
img,
|
||||||
|
pre_processing,
|
||||||
|
output_processing,
|
||||||
|
patch_size,
|
||||||
|
stride,
|
||||||
|
batch_size,
|
||||||
|
device,
|
||||||
|
num_classes,
|
||||||
):
|
):
|
||||||
"""Processes a whole section
|
"""Processes a whole section
|
||||||
"""
|
"""
|
||||||
|
@ -221,14 +255,19 @@ def _patch_label_2d(
|
||||||
# generate output:
|
# generate output:
|
||||||
for batch_indexes in _generate_batches(h, w, ps, patch_size, stride, batch_size=batch_size):
|
for batch_indexes in _generate_batches(h, w, ps, patch_size, stride, batch_size=batch_size):
|
||||||
batch = torch.stack(
|
batch = torch.stack(
|
||||||
[pipe(img_p, _extract_patch(hdx, wdx, ps, patch_size), pre_processing,) for hdx, wdx in batch_indexes],
|
[
|
||||||
|
pipe(img_p, _extract_patch(hdx, wdx, ps, patch_size), pre_processing,)
|
||||||
|
for hdx, wdx in batch_indexes
|
||||||
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_output = model(batch.to(device))
|
model_output = model(batch.to(device))
|
||||||
for (hdx, wdx), output in zip(batch_indexes, model_output.detach().cpu()):
|
for (hdx, wdx), output in zip(batch_indexes, model_output.detach().cpu()):
|
||||||
output = output_processing(output)
|
output = output_processing(output)
|
||||||
output_p[:, :, hdx + ps : hdx + ps + patch_size, wdx + ps : wdx + ps + patch_size,] += output
|
output_p[
|
||||||
|
:, :, hdx + ps : hdx + ps + patch_size, wdx + ps : wdx + ps + patch_size,
|
||||||
|
] += output
|
||||||
|
|
||||||
# crop the output_p in the middle
|
# crop the output_p in the middle
|
||||||
output = output_p[:, :, ps:-ps, ps:-ps]
|
output = output_p[:, :, ps:-ps, ps:-ps]
|
||||||
|
@ -253,12 +292,22 @@ def to_image(label_mask, n_classes=6):
|
||||||
|
|
||||||
|
|
||||||
def _evaluate_split(
|
def _evaluate_split(
|
||||||
split, section_aug, model, pre_processing, output_processing, device, running_metrics_overall, config, debug=False
|
split,
|
||||||
|
section_aug,
|
||||||
|
model,
|
||||||
|
pre_processing,
|
||||||
|
output_processing,
|
||||||
|
device,
|
||||||
|
running_metrics_overall,
|
||||||
|
config,
|
||||||
|
debug=False,
|
||||||
):
|
):
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
TestSectionLoader = get_test_loader(config)
|
TestSectionLoader = get_test_loader(config)
|
||||||
test_set = TestSectionLoader(config.DATASET.ROOT, split=split, is_transform=True, augmentations=section_aug,)
|
test_set = TestSectionLoader(
|
||||||
|
config.DATASET.ROOT, split=split, is_transform=True, augmentations=section_aug,
|
||||||
|
)
|
||||||
|
|
||||||
n_classes = test_set.n_classes
|
n_classes = test_set.n_classes
|
||||||
|
|
||||||
|
@ -268,6 +317,21 @@ def _evaluate_split(
|
||||||
logger.info("Running in Debug/Test mode")
|
logger.info("Running in Debug/Test mode")
|
||||||
test_loader = take(1, test_loader)
|
test_loader = take(1, test_loader)
|
||||||
|
|
||||||
|
try:
|
||||||
|
output_dir = generate_path(
|
||||||
|
config.OUTPUT_DIR + "_test",
|
||||||
|
git_branch(),
|
||||||
|
git_hash(),
|
||||||
|
config.MODEL.NAME,
|
||||||
|
current_datetime(),
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
output_dir = generate_path(
|
||||||
|
config.OUTPUT_DIR + "_test",
|
||||||
|
config.MODEL.NAME,
|
||||||
|
current_datetime(),
|
||||||
|
)
|
||||||
|
|
||||||
running_metrics_split = runningScore(n_classes)
|
running_metrics_split = runningScore(n_classes)
|
||||||
|
|
||||||
# testing mode:
|
# testing mode:
|
||||||
|
@ -295,6 +359,10 @@ def _evaluate_split(
|
||||||
running_metrics_split.update(gt, pred)
|
running_metrics_split.update(gt, pred)
|
||||||
running_metrics_overall.update(gt, pred)
|
running_metrics_overall.update(gt, pred)
|
||||||
|
|
||||||
|
# dump images to disk for review
|
||||||
|
mask_to_disk(pred.squeeze(), os.path.join(output_dir, f"{i}_pred.png"))
|
||||||
|
mask_to_disk(gt.squeeze(), os.path.join(output_dir, f"{i}_gt.png"))
|
||||||
|
|
||||||
# get scores
|
# get scores
|
||||||
score, class_iou = running_metrics_split.get_scores()
|
score, class_iou = running_metrics_split.get_scores()
|
||||||
|
|
||||||
|
@ -350,12 +418,16 @@ def test(*options, cfg=None, debug=False):
|
||||||
running_metrics_overall = runningScore(n_classes)
|
running_metrics_overall = runningScore(n_classes)
|
||||||
|
|
||||||
# Augmentation
|
# Augmentation
|
||||||
section_aug = Compose([Normalize(mean=(config.TRAIN.MEAN,), std=(config.TRAIN.STD,), max_pixel_value=1,)])
|
section_aug = Compose(
|
||||||
|
[Normalize(mean=(config.TRAIN.MEAN,), std=(config.TRAIN.STD,), max_pixel_value=1,)]
|
||||||
|
)
|
||||||
|
|
||||||
patch_aug = Compose(
|
patch_aug = Compose(
|
||||||
[
|
[
|
||||||
Resize(
|
Resize(
|
||||||
config.TRAIN.AUGMENTATIONS.RESIZE.HEIGHT, config.TRAIN.AUGMENTATIONS.RESIZE.WIDTH, always_apply=True,
|
config.TRAIN.AUGMENTATIONS.RESIZE.HEIGHT,
|
||||||
|
config.TRAIN.AUGMENTATIONS.RESIZE.WIDTH,
|
||||||
|
always_apply=True,
|
||||||
),
|
),
|
||||||
PadIfNeeded(
|
PadIfNeeded(
|
||||||
min_height=config.TRAIN.AUGMENTATIONS.PAD.HEIGHT,
|
min_height=config.TRAIN.AUGMENTATIONS.PAD.HEIGHT,
|
||||||
|
|
|
@ -111,7 +111,9 @@ def run(*options, cfg=None, debug=False):
|
||||||
[
|
[
|
||||||
Normalize(mean=(config.TRAIN.MEAN,), std=(config.TRAIN.STD,), max_pixel_value=1),
|
Normalize(mean=(config.TRAIN.MEAN,), std=(config.TRAIN.STD,), max_pixel_value=1),
|
||||||
Resize(
|
Resize(
|
||||||
config.TRAIN.AUGMENTATIONS.RESIZE.HEIGHT, config.TRAIN.AUGMENTATIONS.RESIZE.WIDTH, always_apply=True,
|
config.TRAIN.AUGMENTATIONS.RESIZE.HEIGHT,
|
||||||
|
config.TRAIN.AUGMENTATIONS.RESIZE.WIDTH,
|
||||||
|
always_apply=True,
|
||||||
),
|
),
|
||||||
PadIfNeeded(
|
PadIfNeeded(
|
||||||
min_height=config.TRAIN.AUGMENTATIONS.PAD.HEIGHT,
|
min_height=config.TRAIN.AUGMENTATIONS.PAD.HEIGHT,
|
||||||
|
@ -151,9 +153,14 @@ def run(*options, cfg=None, debug=False):
|
||||||
n_classes = train_set.n_classes
|
n_classes = train_set.n_classes
|
||||||
|
|
||||||
train_loader = data.DataLoader(
|
train_loader = data.DataLoader(
|
||||||
train_set, batch_size=config.TRAIN.BATCH_SIZE_PER_GPU, num_workers=config.WORKERS, shuffle=True,
|
train_set,
|
||||||
|
batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
|
||||||
|
num_workers=config.WORKERS,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
val_loader = data.DataLoader(
|
||||||
|
val_set, batch_size=config.VALIDATION.BATCH_SIZE_PER_GPU, num_workers=config.WORKERS,
|
||||||
)
|
)
|
||||||
val_loader = data.DataLoader(val_set, batch_size=config.VALIDATION.BATCH_SIZE_PER_GPU, num_workers=config.WORKERS,)
|
|
||||||
|
|
||||||
model = getattr(models, config.MODEL.NAME).get_seg_model(config)
|
model = getattr(models, config.MODEL.NAME).get_seg_model(config)
|
||||||
|
|
||||||
|
@ -170,14 +177,18 @@ def run(*options, cfg=None, debug=False):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
output_dir = generate_path(config.OUTPUT_DIR, git_branch(), git_hash(), config.MODEL.NAME, current_datetime(),)
|
output_dir = generate_path(
|
||||||
|
config.OUTPUT_DIR, git_branch(), git_hash(), config.MODEL.NAME, current_datetime(),
|
||||||
|
)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
output_dir = generate_path(config.OUTPUT_DIR, config.MODEL.NAME, current_datetime(),)
|
output_dir = generate_path(config.OUTPUT_DIR, config.MODEL.NAME, current_datetime(),)
|
||||||
|
|
||||||
summary_writer = create_summary_writer(log_dir=path.join(output_dir, config.LOG_DIR))
|
summary_writer = create_summary_writer(log_dir=path.join(output_dir, config.LOG_DIR))
|
||||||
|
|
||||||
snapshot_duration = scheduler_step * len(train_loader)
|
snapshot_duration = scheduler_step * len(train_loader)
|
||||||
scheduler = CosineAnnealingScheduler(optimizer, "lr", config.TRAIN.MAX_LR, config.TRAIN.MIN_LR, snapshot_duration)
|
scheduler = CosineAnnealingScheduler(
|
||||||
|
optimizer, "lr", config.TRAIN.MAX_LR, config.TRAIN.MIN_LR, snapshot_duration
|
||||||
|
)
|
||||||
|
|
||||||
# weights are inversely proportional to the frequency of the classes in the
|
# weights are inversely proportional to the frequency of the classes in the
|
||||||
# training set
|
# training set
|
||||||
|
@ -190,7 +201,8 @@ def run(*options, cfg=None, debug=False):
|
||||||
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
|
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
|
||||||
|
|
||||||
trainer.add_event_handler(
|
trainer.add_event_handler(
|
||||||
Events.ITERATION_COMPLETED, logging_handlers.log_training_output(log_interval=config.PRINT_FREQ),
|
Events.ITERATION_COMPLETED,
|
||||||
|
logging_handlers.log_training_output(log_interval=config.PRINT_FREQ),
|
||||||
)
|
)
|
||||||
trainer.add_event_handler(Events.EPOCH_STARTED, logging_handlers.log_lr(optimizer))
|
trainer.add_event_handler(Events.EPOCH_STARTED, logging_handlers.log_lr(optimizer))
|
||||||
trainer.add_event_handler(
|
trainer.add_event_handler(
|
||||||
|
@ -208,7 +220,9 @@ def run(*options, cfg=None, debug=False):
|
||||||
prepare_batch,
|
prepare_batch,
|
||||||
metrics={
|
metrics={
|
||||||
"nll": Loss(criterion, output_transform=_select_pred_and_mask),
|
"nll": Loss(criterion, output_transform=_select_pred_and_mask),
|
||||||
"pixacc": pixelwise_accuracy(n_classes, output_transform=_select_pred_and_mask, device=device),
|
"pixacc": pixelwise_accuracy(
|
||||||
|
n_classes, output_transform=_select_pred_and_mask, device=device
|
||||||
|
),
|
||||||
"cacc": class_accuracy(n_classes, output_transform=_select_pred_and_mask),
|
"cacc": class_accuracy(n_classes, output_transform=_select_pred_and_mask),
|
||||||
"mca": mean_class_accuracy(n_classes, output_transform=_select_pred_and_mask),
|
"mca": mean_class_accuracy(n_classes, output_transform=_select_pred_and_mask),
|
||||||
"ciou": class_iou(n_classes, output_transform=_select_pred_and_mask),
|
"ciou": class_iou(n_classes, output_transform=_select_pred_and_mask),
|
||||||
|
@ -267,11 +281,15 @@ def run(*options, cfg=None, debug=False):
|
||||||
)
|
)
|
||||||
evaluator.add_event_handler(
|
evaluator.add_event_handler(
|
||||||
Events.EPOCH_COMPLETED,
|
Events.EPOCH_COMPLETED,
|
||||||
create_image_writer(summary_writer, "Validation/Mask", "mask", transform_func=transform_func),
|
create_image_writer(
|
||||||
|
summary_writer, "Validation/Mask", "mask", transform_func=transform_func
|
||||||
|
),
|
||||||
)
|
)
|
||||||
evaluator.add_event_handler(
|
evaluator.add_event_handler(
|
||||||
Events.EPOCH_COMPLETED,
|
Events.EPOCH_COMPLETED,
|
||||||
create_image_writer(summary_writer, "Validation/Pred", "y_pred", transform_func=transform_pred),
|
create_image_writer(
|
||||||
|
summary_writer, "Validation/Pred", "y_pred", transform_func=transform_pred
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def snapshot_function():
|
def snapshot_function():
|
||||||
|
|
Загрузка…
Ссылка в новой задаче