Adding image segmentation 01 and 11 notebooks

This commit is contained in:
PatrickBue 2020-06-10 00:36:49 +00:00 коммит произвёл GitHub
Родитель 3dd9914bff
Коммит 97187bc341
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
23 изменённых файлов: 3024 добавлений и 648 удалений

Просмотреть файл

@ -2,7 +2,7 @@
```diff
+ March 27: Released v1.1 with new and improved
+ functionality for image retrieval, object detection,
+ functionality for image retrieval, object detection,
+ keypoint detection and action recognition.
+ For additional details, please refer to our releases page.
```
@ -41,7 +41,7 @@ instructions on how to setup the compute environment and dependencies needed to
notebooks in this repo. Once your environment is setup, navigate to the
[Scenarios](scenarios) folder and start exploring the notebooks.
Alternatively, we support Binder
Alternatively, we support Binder
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/PatrickBue/computervision-recipes/master?filepath=scenarios%2Fclassification%2F01_training_introduction_BINDER.ipynb)
which makes it easy to try one of our notebooks in a web-browser simply by following this link. However, Binder is free, and as as result only comes with limited CPU compute power and without GPU support. Expect the notebook to run very slowly (this is somewhat improved by reducing image resolution to e.g. 60 pixels but at the cost of low accuracies).
@ -51,10 +51,11 @@ The following is a summary of commonly used Computer Vision scenarios that are c
| Scenario | Support | Description |
| -------- | ----------- | ----------- |
| [Classification](scenarios/classification) | Base | Image Classification is a supervised machine learning technique that allows you to learn and predict the category of a given image. |
| [Classification](scenarios/classification) | Base | Image Classification is a supervised machine learning technique to learn and predict the category of a given image. |
| [Similarity](scenarios/similarity) | Base | Image Similarity is a way to compute a similarity score given a pair of images. Given an image, it allows you to identify the most similar image in a given dataset. |
| [Detection](scenarios/detection) | Base | Object Detection is a technique that allows you to detect the bounding box of an object within an image. |
| [Keypoints](scenarios/keypoints) | Base | Keypoint detection can be used to detect specific points on an object. A pre-trained model is provided to detect body joints for human pose estimation. |
| [Segmentation](scenarios/segmentation) | Base | Image Segmentation assigns a category to each pixel in an image. |
| [Action recognition](contrib/action_recognition) | Contrib | Action recognition to identify in video/webcam footage what actions are performed (e.g. "running", "opening a bottle") and at what respective start/end times.|
| [Crowd counting](contrib/crowd_counting) | Contrib | Counting the number of people in low-crowd-density (e.g. less than 10 people) and high-crowd-density (e.g. thousands of people) scenarios.|
@ -67,7 +68,7 @@ Note that for certain computer vision problems, you may not need to build your o
The following Microsoft services offer simple solutions to address common computer vision tasks:
- [Vision Services](https://docs.microsoft.com/en-us/azure/cognitive-services/computer-vision/)
are a set of pre-trained REST APIs which can be called for image tagging, face recognition, OCR, video analytics, and more. These APIs work out of the box and require minimal expertise in machine learning, but have limited customization capabilities. See the various demos available to get a feel for the functionality (e.g. [Computer Vision](https://azure.microsoft.com/en-us/services/cognitive-services/computer-vision/#analyze)). The service can be used through API calls or through SDKs (available in .NET, Python, Java, Node and Go languages)
are a set of pre-trained REST APIs which can be called for image tagging, face recognition, OCR, video analytics, and more. These APIs work out of the box and require minimal expertise in machine learning, but have limited customization capabilities. See the various demos available to get a feel for the functionality (e.g. [Computer Vision](https://azure.microsoft.com/en-us/services/cognitive-services/computer-vision/#analyze)). The service can be used through API calls or through SDKs (available in .NET, Python, Java, Node and Go languages)
- [Custom Vision](https://docs.microsoft.com/en-us/azure/cognitive-services/custom-vision-service/home)
is a SaaS service to train and deploy a model as a REST API given a user-provided training set. All steps including image upload, annotation, and model deployment can be performed using an intuitive UI or through SDKs (available in .NEt, Python, Java, Node and Go languages). Training image classification or object detection models can be achieved with minimal machine learning expertise. The Custom Vision offers more flexibility than using the pre-trained cognitive services APIs, but requires the user to bring and annotate their own data.

Просмотреть файл

@ -2,10 +2,13 @@
| Scenario | Description |
| -------- | ----------- |
| [Classification](classification) | Image Classification is a supervised machine learning technique that allows you to learn and predict the category of a given image. |
| [Classification](classification) | Image Classification is a supervised machine learning technique to learn and predict the category of a given image. |
| [Similarity](similarity) | Image Similarity is a way to compute a similarity score given a pair of images. Given an image, it allows you to identify the most similar image in a given dataset. |
| [Detection](detection) | Object Detection is a technique that allows you to detect the bounding box of an object within an image. |
| [Keypoints](keypoints) | Keypoint detection can be used to detect specific points on an object. A pre-trained model is provided to detect body joints for human pose estimation. |
| [Keypoints](keypoints) | Keypoint Detection can be used to detect specific points on an object. A pre-trained model is provided to detect body joints for human pose estimation. |
| [Segmentation](segmentation) | Image Segmentation assigns a category to each pixel in an image. |
| [Action Recognition](action_recognition) | Action Recognition (also known as activity recognition) consists of classifying various actions from a sequence of frames, such as "reading" or "drinking". |
# Scenarios

Двоичные данные
scenarios/media/cv_overview.jpg

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 105 KiB

После

Ширина:  |  Высота:  |  Размер: 118 KiB

Двоичные данные
scenarios/media/figures.pptx

Двоичный файл не отображается.

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -0,0 +1,25 @@
# Image segmentation
This directory provides examples and best practices for building image segmentation systems. Our goal is to enable the users to bring their own datasets and train a high-accuracy model easily and quickly.
| Image segmentation example |
|--|
|<img align="center" src="./media/imseg_example.jpg" height="300"/>|
Our implementation uses fastai's [UNet](https://docs.fast.ai/vision.models.unet.html) model, where the CNN backbone (e.g. ResNet) is pre-trained on ImageNet and hence can be fine-tuned with only small amounts of annotated training examples. A good understanding of [image classification](../classification) concepts, while not necessary, is strongly recommended.
## Notebooks
The following notebooks are provided:
| Notebook name | Description |
| --- | --- |
| [01_training_introduction.ipynb](./01_training_introduction.ipynb)| Notebook to train and evaluate an image segmentation model.|
| [11_exploring_hyperparameters.ipynb](11_exploring_hyperparameters.ipynb)| Finds optimal model parameters using grid search. |
## Contribution guidelines
See the [contribution guidelines](../../CONTRIBUTING.md) in the root folder.

Двоичные данные
scenarios/segmentation/media/imseg_example.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 57 KiB

Двоичные данные
scenarios/segmentation/media/param_sweep.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 39 KiB

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -16,7 +16,15 @@ import random
from PIL import Image
from torch import tensor
from pathlib import Path
from fastai.vision import cnn_learner, DatasetType, models
from fastai.vision import (
cnn_learner,
unet_learner,
DatasetType,
get_image_files,
get_transforms,
models,
SegmentationItemList,
)
from fastai.vision.data import ImageList, imagenet_stats
from typing import List, Tuple
from tempfile import TemporaryDirectory
@ -36,6 +44,13 @@ from utils_cv.detection.model import (
_extract_od_results,
_apply_threshold,
)
from utils_cv.segmentation.data import Urls as seg_urls
from utils_cv.segmentation.dataset import load_im, load_mask
from utils_cv.segmentation.model import (
confusion_matrix,
get_objective_fct,
predict,
)
from utils_cv.similarity.data import Urls as is_urls
from utils_cv.similarity.model import compute_features_learner
@ -85,6 +100,18 @@ def path_action_recognition_notebooks():
)
def path_segmentation_notebooks():
""" Returns the path of the similarity notebooks folder. """
return os.path.abspath(
os.path.join(
os.path.dirname(__file__),
os.path.pardir,
"scenarios",
"segmentation",
)
)
# ----- Module fixtures ----------------------------------------------------------
@ -156,7 +183,9 @@ def detection_notebooks():
"01": os.path.join(folder_notebooks, "01_training_introduction.ipynb"),
"02": os.path.join(folder_notebooks, "02_mask_rcnn.ipynb"),
"03": os.path.join(folder_notebooks, "03_keypoint_rcnn.ipynb"),
"04": os.path.join(folder_notebooks, "04_coco_accuracy_vs_speed.ipynb"),
"04": os.path.join(
folder_notebooks, "04_coco_accuracy_vs_speed.ipynb"
),
"11": os.path.join(
folder_notebooks, "11_exploring_hyperparameters_on_azureml.ipynb"
),
@ -184,6 +213,18 @@ def action_recognition_notebooks():
return paths
@pytest.fixture(scope="module")
def segmentation_notebooks():
folder_notebooks = path_segmentation_notebooks()
# Path for the notebooks
paths = {
"01": os.path.join(folder_notebooks, "01_training_introduction.ipynb"),
"11": os.path.join(folder_notebooks, "11_exploring_hyperparameters.ipynb"),
}
return paths
# ----- Function fixtures ----------------------------------------------------------
@ -723,6 +764,7 @@ def ar_path(tmp_session) -> str:
# ----- AML Settings ----------------------------------------------------------
@pytest.fixture(scope="session")
def coco_sample_path(tmpdir_factory) -> str:
""" Returns the path to a coco-formatted annotation. """
@ -794,6 +836,7 @@ def tiny_is_data_path(tmp_session) -> str:
@pytest.fixture(scope="session")
def tiny_ic_databunch_valid_features(tiny_ic_databunch):
""" Returns DNN features for the tiny fridge objects dataset. """
learn = cnn_learner(tiny_ic_databunch, models.resnet18)
embedding_layer = learn.model[1][6]
features = compute_features_learner(
@ -801,3 +844,87 @@ def tiny_ic_databunch_valid_features(tiny_ic_databunch):
)
return features
# ------|-- Segmentation ---------------------------------------------
@pytest.fixture(scope="session")
def tiny_seg_data_path(tmp_session, seg_classes) -> str:
""" Returns the path to the segmentation tiny fridge objects dataset. """
path = unzip_url(
seg_urls.fridge_objects_tiny_path,
fpath=tmp_session,
dest=tmp_session,
exist_ok=True,
)
classes_path = Path(path) / "classes.txt"
with open(classes_path, "w") as f:
for c in seg_classes:
f.write(c + "\n")
return path
@pytest.fixture(scope="session")
def tiny_seg_databunch(tiny_seg_data_path, seg_classes):
""" Returns a databunch object for the segmentation tiny fridge objects dataset. """
get_gt_filename = (
lambda x: f"{tiny_seg_data_path}/segmentation-masks/{x.stem}.png"
)
return (
SegmentationItemList.from_folder(tiny_seg_data_path)
.split_by_rand_pct(valid_pct=0.1, seed=10)
.label_from_func(get_gt_filename, classes=seg_classes)
.transform(get_transforms(), tfm_y=True, size=50)
.databunch(bs=8, num_workers=db_num_workers())
.normalize(imagenet_stats)
)
@pytest.fixture(scope="session")
def seg_classes() -> List[str]:
""" Returns the segmentation class names. """
return ["background", "can", "carton", "milk_bottle", "water_bottle"]
@pytest.fixture(scope="session")
def seg_classes_path(tiny_seg_data_path) -> str:
""" Returns the path to file with class names. """
return Path(tiny_seg_data_path) / "classes.txt"
@pytest.fixture(scope="session")
def seg_im_mask_paths(tiny_seg_data_path) -> str:
""" Returns path to images and their corresponding masks. """
im_dir = Path(tiny_seg_data_path) / "images"
mask_dir = Path(tiny_seg_data_path) / "segmentation-masks"
im_paths = sorted(get_image_files(im_dir))
mask_paths = sorted(get_image_files(mask_dir))
return im_paths, mask_paths
@pytest.fixture(scope="session")
def seg_im_and_mask(seg_im_mask_paths) -> str:
""" Returns a single image with its mask. """
im = load_im(seg_im_mask_paths[0][0])
mask = load_mask(seg_im_mask_paths[1][0])
return im, mask
@pytest.fixture(scope="session")
def seg_learner(tiny_seg_databunch, seg_classes):
return unet_learner(
tiny_seg_databunch,
models.resnet18,
wd=1e-2,
metrics=get_objective_fct(seg_classes),
)
@pytest.fixture(scope="session")
def seg_prediction(seg_learner, seg_im_and_mask):
return predict(seg_im_and_mask[0], seg_learner)
@pytest.fixture(scope="session")
def seg_confusion_matrices(seg_learner, tiny_seg_databunch):
return confusion_matrix(seg_learner, tiny_seg_databunch.valid_dl)

Просмотреть файл

@ -0,0 +1,55 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import papermill as pm
import pytest
import numpy as np
import scrapbook as sb
# Parameters
KERNEL_NAME = "python3"
OUTPUT_NOTEBOOK = "output.ipynb"
@pytest.mark.notebooks
@pytest.mark.linuxgpu
def test_01_notebook_integration_run(segmentation_notebooks):
notebook_path = segmentation_notebooks["01"]
pm.execute_notebook(
notebook_path,
OUTPUT_NOTEBOOK,
parameters=dict(PM_VERSION=pm.__version__),
kernel_name=KERNEL_NAME,
)
nb_output = sb.read_notebook(OUTPUT_NOTEBOOK)
overall_accuracy = nb_output.scraps["validation_overall_accuracy"].data
class_accuracies = nb_output.scraps["validation_class_accuracies"].data
assert len(class_accuracies) == 5
assert overall_accuracy >= 90
for acc in class_accuracies:
assert acc > 80
@pytest.mark.notebooks
@pytest.mark.linuxgpu
def test_11_notebook_integration_run(segmentation_notebooks):
notebook_path = segmentation_notebooks["11"]
pm.execute_notebook(
notebook_path,
OUTPUT_NOTEBOOK,
parameters=dict(
PM_VERSION=pm.__version__,
REPS = 1,
),
kernel_name=KERNEL_NAME,
)
nb_output = sb.read_notebook(OUTPUT_NOTEBOOK)
nr_elements = nb_output.scraps["nr_elements"].data
ratio_correct = nb_output.scraps["ratio_correct"].data
max_duration = nb_output.scraps["max_duration"].data
min_duration = nb_output.scraps["min_duration"].data
assert nr_elements == 12
assert min_duration <= 0.8 * max_duration
assert np.max(ratio_correct) > 0.75

Просмотреть файл

@ -0,0 +1,12 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import requests
from utils_cv.classification.data import Urls
def test_urls():
# Test if all urls are valid
all_urls = Urls.all()
for url in all_urls:
with requests.get(url):
pass

Просмотреть файл

@ -0,0 +1,43 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import fastai
import numpy as np
from utils_cv.segmentation.dataset import (
load_im,
load_mask,
read_classes,
mask_area_sizes,
)
def test_load_im(seg_im_mask_paths, seg_im_and_mask):
im = load_im(seg_im_mask_paths[0][0])
assert type(im) == fastai.vision.image.Image
im = load_im(seg_im_and_mask[0])
assert type(im) == fastai.vision.image.Image
def test_load_mask(seg_im_mask_paths, seg_im_and_mask):
mask = load_mask(seg_im_mask_paths[1][0])
assert type(mask) == fastai.vision.image.ImageSegment
mask = load_mask(seg_im_and_mask[1])
assert type(mask) == fastai.vision.image.ImageSegment
def test_read_classes(seg_classes_path, seg_classes):
classes = read_classes(seg_classes_path)
assert len(classes) == len(seg_classes)
for i in range(len(classes)):
assert classes[i] == seg_classes[i]
def test_mask_area_sizes(tiny_seg_databunch):
areas, pixel_counts = mask_area_sizes(tiny_seg_databunch)
assert len(areas) == 5
assert len(pixel_counts) == 5
assert np.sum([np.sum(v) for v in pixel_counts.values()]) == (22 * 499 * 666)
assert type(areas[0]) == list
for i in range(len(areas)):
for area in areas[i]:
assert area > 0

Просмотреть файл

@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import functools
import numpy as np
from utils_cv.segmentation.model import (
get_objective_fct,
predict,
confusion_matrix,
print_accuracies,
)
def test_get_objective_fct(seg_classes):
fct = get_objective_fct(seg_classes)
assert type(fct) == functools.partial
def test_predict(seg_im_mask_paths, seg_learner):
im_path = seg_im_mask_paths[0][0]
mask, scores = predict(im_path, seg_learner)
assert mask.shape[0] == 50 # scores.shape[0] == 50
assert mask.shape[1] == 50 # scores.shape[1] == 50
assert len(scores) == 5
for i in range(len(scores)):
assert mask.shape[0] == scores[i].shape[0]
assert mask.shape[1] == scores[i].shape[1]
def test_confusion_matrix(seg_learner, tiny_seg_databunch):
cmat, cmat_norm = confusion_matrix(
seg_learner, tiny_seg_databunch.valid_dl
)
assert type(cmat) == np.ndarray
assert type(cmat_norm) == np.ndarray
assert cmat.max() > 1.0
assert cmat_norm.max() <= 1.0
def test_print_accuracies(seg_confusion_matrices, seg_classes):
cmat, cmat_norm = seg_confusion_matrices
print_accuracies(cmat, cmat_norm, seg_classes)

Просмотреть файл

@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import papermill as pm
import pytest
import scrapbook as sb
# Parameters
KERNEL_NAME = "python3"
OUTPUT_NOTEBOOK = "output.ipynb"
@pytest.mark.notebooks
def test_01_notebook_run(segmentation_notebooks, tiny_seg_data_path):
notebook_path = segmentation_notebooks["01"]
pm.execute_notebook(
notebook_path,
OUTPUT_NOTEBOOK,
parameters=dict(
PM_VERSION=pm.__version__,
EPOCHS=1,
IM_SIZE=50,
DATA_PATH=tiny_seg_data_path
),
kernel_name=KERNEL_NAME,
)
nb_output = sb.read_notebook(OUTPUT_NOTEBOOK)
overall_accuracy = nb_output.scraps["validation_overall_accuracy"].data
class_accuracies = nb_output.scraps["validation_class_accuracies"].data
assert len(class_accuracies) == 5
@pytest.mark.notebooks
def test_11_notebook_run(segmentation_notebooks, tiny_seg_data_path):
notebook_path = segmentation_notebooks["11"]
pm.execute_notebook(
notebook_path,
OUTPUT_NOTEBOOK,
parameters=dict(
PM_VERSION=pm.__version__,
REPS = 1,
EPOCHS=[1],
IM_SIZE=[50],
LEARNING_RATES = [1e-4],
DATA_PATH=[tiny_seg_data_path]
),
kernel_name=KERNEL_NAME,
)
nb_output = sb.read_notebook(OUTPUT_NOTEBOOK)
nr_elements = nb_output.scraps["nr_elements"].data
ratio_correct = nb_output.scraps["ratio_correct"].data
max_duration = nb_output.scraps["max_duration"].data
min_duration = nb_output.scraps["min_duration"].data
assert nr_elements == 2

Просмотреть файл

@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from utils_cv.segmentation.plot import (
plot_image_and_mask,
plot_segmentation,
plot_mask_stats,
plot_confusion_matrix,
)
def test_plot_image_and_mask(seg_im_and_mask):
plot_image_and_mask(seg_im_and_mask[0], seg_im_and_mask[1])
def test_plot_segmentation(seg_im_and_mask, seg_prediction):
mask, scores = seg_prediction
plot_segmentation(seg_im_and_mask[0], mask, scores)
def test_plot_mask_stats(tiny_seg_databunch, seg_classes):
plot_mask_stats(tiny_seg_databunch, seg_classes)
plot_mask_stats(
tiny_seg_databunch, seg_classes, exclude_classes=["background"]
)
def test_plot_confusion_matrix(seg_confusion_matrices, seg_classes):
cmat, cmat_norm = seg_confusion_matrices
plot_confusion_matrix(cmat, cmat_norm, seg_classes)

Просмотреть файл

@ -21,12 +21,16 @@ from fastai.vision import (
imagenet_stats,
Learner,
models,
SegmentationItemList,
unet_learner,
)
from matplotlib.axes import Axes
from matplotlib.text import Annotation
import pandas as pd
from utils_cv.common.gpu import db_num_workers
from utils_cv.segmentation.dataset import read_classes
from utils_cv.segmentation.model import get_objective_fct
Time = float
@ -162,7 +166,8 @@ def plot_sweeper_df(
for col, ax in zip(cols, axes):
top_val = df[col].max()
ax.set_ylim(top=top_val * 1.2)
min_val = df[col].min()
ax.set_ylim(bottom = min_val/1.01, top=top_val * 1.01)
add_value_labels(ax)
if col in ["accuracy"]:
@ -234,7 +239,7 @@ class ParameterSweeper:
return permutations
@staticmethod
def _get_data_bunch(
def _get_data_bunch_imagelist(
path: Union[Path, str], transform: bool, im_size: int, bs: int
) -> ImageDataBunch:
"""
@ -261,6 +266,40 @@ class ParameterSweeper:
.normalize(imagenet_stats)
)
@staticmethod
def _get_data_bunch_segmentationitemlist(
path: Union[Path, str], transform: bool, im_size: int, bs: int, classes:List[str]
) -> ImageDataBunch:
"""
Create ImageDataBunch and return it. TODO in future version is to allow
users to pass in their own image bunch or their own Transformation
objects (instead of using fastai's <get_transforms>)
Args:
path (Union[Path, str]): path to data to create databunch with
transform (bool): a flag to set fastai default transformations (get_transforms())
im_size (int): image size of databunch
bs (int): batch size of databunch
Returns:
ImageDataBunch
"""
path = path if type(path) is Path else Path(path)
tfms = get_transforms() if transform else None
im_path = path / "images"
anno_path = path / "segmentation-masks"
get_gt_filename = lambda x: anno_path / f"{x.stem}.png"
# Load data
return (
SegmentationItemList.from_folder(im_path)
.split_by_rand_pct(valid_pct=0.33)
.label_from_func(get_gt_filename, classes=classes)
.transform(tfms=tfms, size=im_size, tfm_y=True)
.databunch(bs=bs, num_workers=db_num_workers())
.normalize(imagenet_stats)
)
@staticmethod
def _early_stopping_callback(
metric: str = "accuracy", min_delta: float = 0.01, patience: int = 3
@ -324,7 +363,7 @@ class ParameterSweeper:
)
def _learn(
self, data_path: Path, params: Tuple[Any], stop_early: bool
self, data_path: Path, params: Tuple[Any], stop_early: bool, learner_type = "cnn"
) -> Tuple[Learner, Time]:
"""
Given a set of permutations, create a learner to train and validate on
@ -353,19 +392,39 @@ class ParameterSweeper:
one_cycle_policy = params["one_cycle_policy"]
weight_decay = params["weight_decay"]
data = self._get_data_bunch(data_path, transform, im_size, batch_size)
callbacks = list()
if stop_early:
callbacks.append(ParameterSweeper._early_stopping_callback())
learn = cnn_learner(
data,
architecture.value,
metrics=accuracy,
ps=dropout,
callback_fns=callbacks,
)
# Initialize CNN learner
if learner_type == "cnn":
data = self._get_data_bunch_imagelist(data_path, transform, im_size, batch_size)
learn = cnn_learner(
data,
architecture.value,
metrics=accuracy,
ps=dropout,
callback_fns=callbacks,
)
# Initialize UNet learner
elif learner_type == "unet":
classes = read_classes(os.path.join(data_path, "classes.txt"))
data = self._get_data_bunch_segmentationitemlist(data_path, transform, im_size, batch_size, classes)
metric = get_objective_fct(classes)
metric.__name__ = "ratio_correct"
learn = unet_learner(
data,
architecture.value,
wd=1e-2,
metrics=metric,
callback_fns=callbacks,
)
else:
print(f"Mode learner_type={learner_type} not supported.")
head_learning_rate = learning_rate
body_learning_rate = (
@ -429,6 +488,7 @@ class ParameterSweeper:
reps: int = 3,
early_stopping: bool = False,
metric_fct=None,
learner_type = "cnn"
) -> pd.DataFrame:
""" Performs the experiment.
Iterates through the number of specified <reps>, the list permutations
@ -439,40 +499,52 @@ class ParameterSweeper:
definition.
Args:
datasets (List[Path]): A list of datasets to iterate over.
reps (int): The number of runs to loop over.
early_stopping (bool): Whether we want to perform early stopping.
datasets: A list of datasets to iterate over.
reps: The number of runs to loop over.
early_stopping: Whether we want to perform early stopping.
metric_fct: custom metric function
learner_type: choose between "cnn" and "unet" learners
Returns:
pd.DataFrame: a multi-index dataframe with the results stored in it.
"""
count = 0
res = dict()
for rep in range(reps):
res[rep] = dict()
for i, permutation in enumerate(self.permutations):
print(
f"Running {i+1} of {len(self.permutations)} permutations. "
f"Repeat {rep+1} of {reps}."
)
stringified_permutation = self._serialize_permutations(
permutation
)
res[rep][stringified_permutation] = dict()
for dataset in datasets:
for ii, dataset in enumerate(datasets):
percent_done = round(100.0 * count / (reps * len(self.permutations) * len(datasets)))
print(
f"Percentage done: {percent_done}%. "
f"Currently processing repeat {rep+1} of {reps}, "
f"running {i+1} of {len(self.permutations)} permutations, "
f"dataset {ii+1} of {len(datasets)} ({os.path.basename(dataset)}). "
)
data_name = os.path.basename(dataset)
res[rep][stringified_permutation][data_name] = dict()
learn, duration = self._learn(
dataset, permutation, early_stopping
dataset, permutation, early_stopping, learner_type
)
if metric_fct is None:
if metric_fct is None and learner_type == "cnn":
_, metric = learn.validate(
learn.data.valid_dl, metrics=[accuracy]
learn.data.valid_dl,
metrics=[accuracy]
)
elif learner_type == "unet":
_, metric = learn.validate(
learn.data.valid_dl
)
else:
@ -488,4 +560,6 @@ class ParameterSweeper:
learn.destroy()
count+=1
return self._make_df_from_dict(res)

Просмотреть файл

Просмотреть файл

@ -0,0 +1,17 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from typing import List
from urllib.parse import urljoin
class Urls:
# base url
base = "https://cvbp.blob.core.windows.net/public/datasets/image_segmentation/"
# traditional datasets
fridge_objects_path = urljoin(base, "segFridgeObjects.zip")
fridge_objects_tiny_path = urljoin(base, "segFridgeObjectsTiny.zip")
@classmethod
def all(cls) -> List[str]:
return [v for k, v in cls.__dict__.items() if k.endswith("_path")]

Просмотреть файл

@ -0,0 +1,98 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import collections
from pathlib import Path
from typing import List, Union
import fastai
from fastai.vision import open_image, open_mask
from fastai.vision.data import ImageDataBunch
import numpy as np
from numpy import loadtxt
import PIL
from scipy import ndimage
def load_im(
im_or_path: Union[np.ndarray, Union[str, Path]]
) -> fastai.vision.image.Image:
""" Load image using "open_image" function from fast.ai.
Args:
im_or_path: image object or image location to be loaded
Return:
Image
"""
if isinstance(im_or_path, (str, Path)):
im = open_image(im_or_path, convert_mode="RGB")
else:
im = im_or_path
return im
def load_mask(
mask_or_path: Union[np.ndarray, Union[str, Path]]
) -> fastai.vision.image.ImageSegment:
""" Load mask using "open_mask" function from fast.ai.
Args:
mask_or_path: mask object or mask location to be loaded
Return:
Mask
"""
if isinstance(mask_or_path, (str, Path)):
mask = open_mask(mask_or_path)
else:
mask = mask_or_path
return mask
def read_classes(path: Union[str, Path]) -> List[str]:
""" Read text file with class names.
Args:
path: location of text file where each line is a class name
Return:
List of class names
"""
classes = list(loadtxt(path, dtype=str))
classes = [s.lower() for s in classes]
return classes
def mask_area_sizes(data: ImageDataBunch) -> collections.defaultdict:
""" Compute number of pixels in each connected segment.
Args:
data: databunch with images and ground truth masks
Return:
Sizes of all connected segments, in pixels, and for each class
"""
seg_areas = collections.defaultdict(list)
pixel_counts = collections.defaultdict(list)
# Loop over all class masks
for mask_path in data.y.items:
mask = np.array(PIL.Image.open(mask_path))
# For each class, find all segments and enumerate
for class_id in np.unique(mask):
num_pixels = np.sum(mask == class_id)
pixel_counts[class_id].append(num_pixels)
# Get all connected segments in image
segments, _ = ndimage.label(
mask == class_id,
structure=[[1, 1, 1], [1, 1, 1], [1, 1, 1]]
)
# Loop over each segment of a given label
for segment_id in range(1, segments.max() + 1):
area = np.sum(segments == segment_id)
seg_areas[class_id].append(area)
return seg_areas, pixel_counts

Просмотреть файл

@ -0,0 +1,140 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from functools import partial
from pathlib import Path
from typing import List, Union
from fastai.basic_data import DeviceDataLoader
from fastai.basic_train import Learner
import numpy as np
import PIL
from sklearn.metrics import confusion_matrix as sk_confusion_matrix
from .dataset import load_im
# Ignore pixels marked as void. That could be pixels which are hard to annotate and hence should not influence training.
def _objective_fct_partial(void_id, input, target):
""" Helper function to compute the ratio of correctly classified pixels. """
target = target.squeeze(1)
if void_id:
mask = target != void_id
ratio_correct = (
(input.argmax(dim=1)[mask] == target[mask]).float().mean()
)
else:
ratio_correct = (input.argmax(dim=1) == target).float().mean()
return ratio_correct
def get_objective_fct(classes: List[str]):
""" Returns objective function for model training, defined as ratio of correctly classified pixels.
Args:
classes: list of class names
Return:
Objective function.
"""
class2id = {v: k for k, v in enumerate(classes)}
if "void" in class2id:
void_id = class2id["void"]
else:
void_id = None
return partial(_objective_fct_partial, void_id)
def predict(
im_or_path: Union[np.ndarray, Union[str, Path]],
learn: Learner,
thres: float = None,
) -> [np.ndarray, np.ndarray]:
""" Run model inference.
Args:
im_or_path: image or path to image
learn: trained model
thres: threshold under which to reject predicted label and set to class-id 0 instead.
Return:
The predicted mask with pixel-wise confidence scores.
"""
im = load_im(im_or_path)
_, mask, scores = learn.predict(im, thresh=thres)
mask = np.array(mask).squeeze()
scores = np.array(scores)
# Fastai seems to ignore the confidance threshold 'thresh'. Hence here
# setting all predictions with low confidence to be 'background'.
if thres is not None:
max_scores = np.max(np.array(scores), axis=0)
mask[max_scores <= thres] = 0
return mask, scores
def confusion_matrix(
learn: Learner,
dl: DeviceDataLoader,
thres: float = None
) -> [np.ndarray, np.ndarray]:
""" Compute confusion matrix.
Args:
learn: trained model
dl: dataloader with images and ground truth masks
thres: threshold under which to reject predicted label and set to class-id 0 instead.
Return:
The un-normalized and the normalized confusion matrices.
"""
y_gts = []
y_preds = []
# Loop over all images
for im_path, gt_path in zip(dl.x.items, dl.y.items):
pred_mask, _ = predict(im_path, learn, thres)
# load ground truth and resize to be same size as predited mask
gt_mask = PIL.Image.open(gt_path)
gt_mask = gt_mask.resize(
pred_mask.shape[::-1], resample=PIL.Image.NEAREST
)
gt_mask = np.asarray(gt_mask)
# Store predicted and ground truth labels
assert len(gt_mask.flatten()) == len(pred_mask.flatten())
y_gts.extend(gt_mask.flatten())
y_preds.extend(pred_mask.flatten())
# Compute confusion matrices
cmat = sk_confusion_matrix(y_gts, y_preds)
cmat_norm = sk_confusion_matrix(y_gts, y_preds, normalize="true")
return cmat, cmat_norm
def print_accuracies(
cmat: np.ndarray, cmat_norm: np.ndarray, classes: List[str]
) -> [int, int]:
""" Print accuracies per class, and the overall class-averaged accuracy.
Args:
cmat: confusion matrix (with raw pixel counts)
cmat_norm: normalized confusion matrix
classes: list of class names
Return:
Computed overall and per-class accuracies.
"""
class_accs = 100.0 * np.diag(cmat_norm)
overall_acc = 100.0 * np.diag(cmat).sum() / cmat.sum()
print(f"Overall accuracy: {overall_acc:3.2f}%")
print(f"Class-averaged accuracy: {np.mean(class_accs):3.2f}%")
for acc, cla in zip(class_accs, classes):
print(f"\tClass {cla:>15} has accuracy: {acc:2.2f}%")
return overall_acc, class_accs

Просмотреть файл

@ -0,0 +1,194 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from pathlib import Path
from typing import List, Tuple, Union
from fastai.vision import pil2tensor, show_image
from fastai.vision.data import ImageDataBunch
from matplotlib import cm
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay
from .dataset import load_im, load_mask, mask_area_sizes
# Plot original image(left), ground truth (middle), and overlaid ground truth (right)
def plot_image_and_mask(
im_or_path: Union[np.ndarray, Union[str, Path]],
mask_or_path: Union[np.ndarray, Union[str, Path]],
show: bool = True,
figsize: Tuple[int, int] = (16, 8),
alpha=0.50,
cmap: ListedColormap = cm.get_cmap("Set3"),
) -> None:
""" Plot an image and its ground truth mask.
Args:
im_or_path: image or path to image
mask_or_path: mask or path to mask
show: set to true to call matplotlib's show()
figsize: figure size
alpha: strength of overlying image on mask.
cmap: mask color map.
"""
im = load_im(im_or_path)
mask = load_mask(mask_or_path)
# Plot the image, the mask, and the mask overlaid on image
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize)
show_image(im, ax=ax1)
show_image(mask, ax=ax2, cmap=cmap)
im.show(y=mask, ax=ax3, cmap=cmap, alpha=alpha)
ax1.set_title("Image")
ax2.set_title("Mask")
ax3.set_title("Mask (overlaid on Image)")
if show:
plt.show()
def plot_segmentation(
im_or_path: Union[np.ndarray, Union[str, Path]],
pred_mask: Union[np.ndarray, Union[str, Path]],
pred_scores: np.ndarray,
gt_mask_or_path: Union[np.ndarray, Union[str, Path]] = None,
show: bool = True,
figsize: Tuple[int, int] = (16, 4),
cmap: ListedColormap = cm.get_cmap("Set3"),
) -> None:
""" Plot an image, its predicted mask with associated scores, and optionally the ground truth mask.
Args:
im_or_path: image or path to image
pred_mask: predicted mask
pred_scores: pixel-wise confidence scores in the predictions
gt_mask_or_path: ground truth mask or path to mask
show: set to true to call matplotlib's show()
figsize: figure size
cmap: mask color map.
"""
im = load_im(im_or_path)
pred_mask = pil2tensor(pred_mask, np.float32)
max_scores = np.max(np.array(pred_scores[1:]), axis=0)
max_scores = pil2tensor(max_scores, np.float32)
# Plot groud truth mask if provided
if gt_mask_or_path:
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=figsize)
gt_mask = load_mask(gt_mask_or_path)
show_image(gt_mask, ax=ax4, cmap=cmap)
ax4.set_title("Ground truth mask")
else:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize)
# Plot image, predicted mask, and prediction scores
show_image(im, ax=ax1)
show_image(pred_mask, ax=ax2, cmap=cmap)
show_image(max_scores, ax=ax3, cmap=cm.get_cmap("gist_heat"))
ax1.set_title("Image")
ax2.set_title("Predicted mask")
ax3.set_title("Predicted scores")
if show:
plt.show()
def plot_mask_stats(
data: ImageDataBunch,
classes: List[str],
show: bool = True,
figsize: Tuple[int, int] = (15, 3),
nr_bins: int = 50,
exclude_classes: list = None,
) -> None:
""" Plot statistics of the ground truth masks such as number or size of segments.
Args:
data: databunch with images and ground truth masks
classes: list of class names
show: set to true to call matplotlib's show()
figsize: figure size
nr_bins: number of bins for segment sizes histogram
exclude_classes: list of classes to ignore, e.g. ["background"]
"""
areas, pixel_counts = mask_area_sizes(data)
class_names = [classes[k] for k,v in areas.items()]
values_list = [v for k,v in areas.items()]
seg_counts = [len(v) for v in values_list]
pixel_counts = [np.sum(v) for v in pixel_counts.values()]
assert exclude_classes is None or type(exclude_classes) == list
# Remove specified classes
if exclude_classes:
keep_indices = np.where(
[c not in set(exclude_classes) for c in class_names]
)[0]
class_names = [class_names[i] for i in keep_indices]
values_list = [values_list[i] for i in keep_indices]
seg_counts = [seg_counts[i] for i in keep_indices]
pixel_counts = [pixel_counts[i] for i in keep_indices]
# Left plot
plt.subplots(1, 3, figsize=figsize)
plt.subplot(1, 3, 1)
plt.barh(range(len(class_names)), pixel_counts)
plt.gca().set_yticks(range(len(class_names)))
plt.gca().set_yticklabels(class_names)
plt.xlabel("Number of pixels per class")
plt.title("Distribution of pixel labels")
# Middle plot
plt.subplot(1, 3, 2)
plt.barh(range(len(class_names)), seg_counts)
plt.gca().set_yticks(range(len(class_names)))
plt.gca().set_yticklabels(class_names)
plt.xlabel("Number of segments per class")
plt.title("Distribution of segment labels")
# Right plot
plt.subplot(1, 3, 3)
plt.hist(
values_list, nr_bins, label=class_names, histtype="barstacked",
)
plt.title("Distribution of segment sizes (stacked bar chart)")
plt.legend()
plt.ylabel("Number of segments")
plt.xlabel("Segment sizes [area in pixel]")
if show:
plt.show()
def plot_confusion_matrix(
cmat: np.ndarray,
cmat_norm: np.ndarray,
classes: List[str],
show: bool = True,
figsize: Tuple[int, int] = (16, 4),
) -> None:
""" Plot the confusion matrices.
Args:
cmat: confusion matrix (with raw pixel counts)
cmat_norm: normalized confusion matrix
classes: list of class names
show: set to true to call matplotlib's show()
figsize: figure size
"""
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
ConfusionMatrixDisplay(cmat, classes).plot(
ax=ax1,
cmap=cm.get_cmap("Blues"),
xticks_rotation="vertical",
values_format="d",
)
ConfusionMatrixDisplay(cmat_norm, classes).plot(
ax=ax2, cmap=cm.get_cmap("Blues"), xticks_rotation="vertical"
)
ax1.set_title("Confusion matrix")
ax2.set_title("Normalized confusion matrix")
if show:
plt.show()