adding support for void label to unet segmentation

This commit is contained in:
PatrickBue 2021-08-18 12:52:11 -04:00
Родитель c4485f26aa
Коммит eb4cf083ec
3 изменённых файлов: 125 добавлений и 79 удалений

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -17,7 +17,7 @@ from .dataset import load_im
def ratio_correct(void_id, input, target):
""" Helper function to compute the ratio of correctly classified pixels. """
target = target.squeeze(1)
if void_id:
if void_id != None:
mask = target != void_id
ratio_correct = (
(input.argmax(dim=1)[mask] == target[mask]).float().mean()

Просмотреть файл

@ -57,6 +57,7 @@ def plot_segmentation(
show: bool = True,
figsize: Tuple[int, int] = (16, 4),
cmap: ListedColormap = cm.get_cmap("Set3"),
ignore_background_label = True
) -> None:
""" Plot an image, its predicted mask with associated scores, and optionally the ground truth mask.
@ -68,10 +69,15 @@ def plot_segmentation(
show: set to true to call matplotlib's show()
figsize: figure size
cmap: mask color map.
ignore_background_label: set to True to ignore the 0 label.
"""
im = load_im(im_or_path)
pred_mask = pil2tensor(pred_mask, np.float32)
max_scores = np.max(np.array(pred_scores[1:]), axis=0)
if ignore_background_label:
start_label = 1
else:
start_label = 0
max_scores = np.max(np.array(pred_scores[start_label:]), axis=0)
max_scores = pil2tensor(max_scores, np.float32)
# Plot groud truth mask if provided