2019-07-23 00:12:30 +03:00
|
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
# Licensed under the MIT License.
|
|
|
|
|
2019-06-26 19:32:52 +03:00
|
|
|
import os
|
|
|
|
import numpy as np
|
|
|
|
from utils_cv.classification.widget import AnnotationWidget, ResultsWidget
|
|
|
|
|
|
|
|
|
|
|
|
def test_annotation_widget(tiny_ic_data_path, tmp):
    """Smoke-test constructing an AnnotationWidget and refreshing its UI.

    The widget is pointed at the "can" sub-folder of the tiny image
    classification dataset, with its annotation file path placed inside the
    ``tmp`` fixture directory. The test only checks that construction and
    ``update_ui()`` run without raising.
    """
    annotation_file = os.path.join(tmp, "cvbp_ic_annotation.txt")
    class_names = ["can", "carton", "milk_bottle", "water_bottle"]
    image_folder = os.path.join(tiny_ic_data_path, "can")

    widget = AnnotationWidget(
        labels=class_names,
        im_dir=image_folder,
        anno_path=annotation_file,
        # None requests that all images in im_dir be annotated.
        im_filenames=None,
    )
    widget.update_ui()
|
|
|
|
|
|
|
|
|
|
|
|
def test_results_widget(model_pred_scores):
    """Smoke-test constructing a ResultsWidget from model predictions.

    ``model_pred_scores`` is unpacked into a learner and its prediction
    scores. ``y_label`` is built by mapping the arg-max index of each score
    row (``axis=1``) to the learner's class names. The test only checks that
    construction and ``update()`` run without raising.
    """
    learner, scores = model_pred_scores

    validation_set = learner.data.valid_ds
    best_class_indices = np.argmax(scores, axis=1)
    predicted_labels = [learner.data.classes[idx] for idx in best_class_indices]

    results_widget = ResultsWidget(
        dataset=validation_set,
        y_score=scores,
        y_label=predicted_labels,
    )
    results_widget.update()
|