Small updates to image segmentation
This commit is contained in:
PatrickBue 2020-06-15 20:04:27 +00:00 коммит произвёл GitHub
Родитель 28f474bed0
Коммит 6d76141204
5 изменённых файлов: 73 добавлений и 72 удалений

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -50,7 +50,7 @@ from utils_cv.segmentation.data import Urls as seg_urls
from utils_cv.segmentation.dataset import load_im, load_mask
from utils_cv.segmentation.model import (
confusion_matrix,
get_objective_fct,
get_ratio_correct_metric,
predict,
)
from utils_cv.similarity.data import Urls as is_urls
@ -970,7 +970,7 @@ def seg_learner(tiny_seg_databunch, seg_classes):
tiny_seg_databunch,
models.resnet18,
wd=1e-2,
metrics=get_objective_fct(seg_classes),
metrics=get_ratio_correct_metric(seg_classes),
)

Просмотреть файл

@ -4,15 +4,15 @@ import functools
import numpy as np
from utils_cv.segmentation.model import (
get_objective_fct,
get_ratio_correct_metric,
predict,
confusion_matrix,
print_accuracies,
)
def test_get_objective_fct(seg_classes):
fct = get_objective_fct(seg_classes)
def test_get_ratio_correct_metric(seg_classes):
fct = get_ratio_correct_metric(seg_classes)
assert type(fct) == functools.partial

Просмотреть файл

@ -21,6 +21,7 @@ from fastai.vision import (
imagenet_stats,
Learner,
models,
ResizeMethod,
SegmentationItemList,
unet_learner,
)
@ -30,7 +31,7 @@ import pandas as pd
from utils_cv.common.gpu import db_num_workers
from utils_cv.segmentation.dataset import read_classes
from utils_cv.segmentation.model import get_objective_fct
from utils_cv.segmentation.model import get_ratio_correct_metric
Time = float
@ -294,7 +295,7 @@ class ParameterSweeper:
SegmentationItemList.from_folder(im_path)
.split_by_rand_pct(valid_pct=0.33)
.label_from_func(get_gt_filename, classes=classes)
.transform(tfms=tfms, size=im_size, tfm_y=True)
.transform(tfms=tfms, resize_method = ResizeMethod.CROP, size=im_size, tfm_y=True)
.databunch(bs=bs, num_workers=db_num_workers())
.normalize(imagenet_stats)
)
@ -412,7 +413,7 @@ class ParameterSweeper:
elif learner_type == "unet":
classes = read_classes(os.path.join(data_path, "classes.txt"))
data = self._get_data_bunch_segmentationitemlist(data_path, transform, im_size, batch_size, classes)
metric = get_objective_fct(classes)
metric = get_ratio_correct_metric(classes)
metric.__name__ = "ratio_correct"
learn = unet_learner(
data,

Просмотреть файл

@ -14,7 +14,7 @@ from .dataset import load_im
# Ignore pixels marked as void. That could be pixels which are hard to annotate and hence should not influence training.
def _objective_fct_partial(void_id, input, target):
def ratio_correct(void_id, input, target):
""" Helper function to compute the ratio of correctly classified pixels. """
target = target.squeeze(1)
if void_id:
@ -28,8 +28,8 @@ def _objective_fct_partial(void_id, input, target):
return ratio_correct
def get_objective_fct(classes: List[str]):
""" Returns objective function for model training, defined as ratio of correctly classified pixels.
def get_ratio_correct_metric(classes: List[str]):
""" Returns metric which computes the ratio of correctly classified pixels.
Args:
classes: list of class names
@ -43,7 +43,7 @@ def get_objective_fct(classes: List[str]):
else:
void_id = None
return partial(_objective_fct_partial, void_id)
return partial(ratio_correct, void_id)
def predict(