# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
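
"""Unit tests for the dataset helpers in utils_cv.classification.data."""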

import requests
from PIL import Image
from fastai.vision.data import ImageList

from utils_cv.classification.data import (
    downsize_imagelist,
    imagenet_labels,
    is_data_multilabel,
    Urls,
)


def test_imagenet_labels():
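    """Spot-check the first five ImageNet labels and the total label count."""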
    # Compare first five labels for quick check
    IMAGENET_LABELS_FIRST_FIVE = (
        "tench",
        "goldfish",
        "great_white_shark",
        "tiger_shark",
        "hammerhead",
    )

    labels = imagenet_labels()
    for i in range(5):
        assert labels[i] == IMAGENET_LABELS_FIRST_FIVE[i]

    # Check total number of labels
    assert len(labels) == 1000


def test_downsize_imagelist(tiny_ic_data_path, tmp):
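    """Downsize a tiny ImageList and verify the written copies.

    The pytest fixtures `tiny_ic_data_path` and `tmp` supply a small image
    dataset and an output folder; every downsized copy should have its
    smaller side no larger than `max_dim`.
    """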
    im_list = ImageList.from_folder(tiny_ic_data_path)
    max_dim = 50
    downsize_imagelist(im_list, tmp, max_dim)
    im_list2 = ImageList.from_folder(tmp)
    assert len(im_list) == len(im_list2)
    for im_path in im_list2.items:
        assert min(Image.open(im_path).size) <= max_dim


def test_is_data_multilabel(tiny_multilabel_ic_data_path, tiny_ic_data_path):
    """
    Tests that multilabel classification datasets and traditional
    classification datasets are correctly identified
    """
    assert is_data_multilabel(tiny_multilabel_ic_data_path)
    assert not is_data_multilabel(tiny_ic_data_path)


def test_urls():
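    """Check that every url in Urls is reachable via requests.get."""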
    # Test that all urls can be downloaded; an unreachable url raises
    # an exception and fails the test
    all_urls = Urls.all()
    for url in all_urls:
        with requests.get(url):
            pass