# computervision-recipes/tests/unit/classification/test_classification_model.py
# (161 lines, 4.9 KiB, Python)

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import numpy as np
import urllib.request
from torch import tensor
from fastai.metrics import accuracy, error_rate
from fastai.vision import (
cnn_learner,
DatasetType,
ImageList,
imagenet_stats,
models,
open_image,
)
from utils_cv.classification.model import (
get_optimal_threshold,
get_preds,
hamming_accuracy,
IMAGENET_IM_SIZE,
model_to_learner,
TrainMetricsRecorder,
zero_one_accuracy,
)
def test_hamming_accuracy(multilabel_result):
    """Check hamming_accuracy (1 - hamming loss) under default, sigmoid,
    and custom-threshold settings."""
    scores, labels = multilabel_result

    # Default threshold on raw scores: hamming loss 0.1875.
    assert hamming_accuracy(scores, labels) == tensor(0.8125)
    # Squash scores through a sigmoid first: hamming loss 0.375.
    assert hamming_accuracy(scores, labels, sigmoid=True) == tensor(0.625)
    # Threshold of 1.0 treats almost everything as negative: loss 0.625.
    assert hamming_accuracy(scores, labels, threshold=1.0) == tensor(0.375)
def test_zero_one_accuracy(multilabel_result):
    """Check zero_one_accuracy (1 - zero-one loss) under default, sigmoid,
    and custom-threshold settings."""
    scores, labels = multilabel_result

    # Default threshold on raw scores: zero-one loss 0.75.
    assert zero_one_accuracy(scores, labels) == tensor(0.25)
    # Sigmoid-activated scores give the same zero-one loss here.
    assert zero_one_accuracy(scores, labels, sigmoid=True) == tensor(0.25)
    # Threshold of 1.0: every example has at least one wrong label.
    assert zero_one_accuracy(scores, labels, threshold=1.0) == tensor(0.0)
def test_get_optimal_threshold(multilabel_result):
    """Check get_optimal_threshold with the default fine grid and with a
    coarser user-supplied grid, for both supported metrics."""
    scores, labels = multilabel_result
    coarse_grid = np.linspace(0, 1, 11)  # 0.0, 0.1, ..., 1.0

    # Same expectations hold for both metric functions; keep the original
    # assertion order (hamming first, then zero-one).
    for metric in (hamming_accuracy, zero_one_accuracy):
        # Default threshold search.
        assert get_optimal_threshold(metric, scores, labels) == 0.05
        # Coarser 0.1-step grid picks the nearest coarse optimum.
        assert (
            get_optimal_threshold(metric, scores, labels, thresholds=coarse_grid)
            == 0.1
        )
def test_model_to_learner(tmp):
    """End-to-end check of model_to_learner.

    Verifies that: (1) the returned learner wraps a pretrained ImageNet
    ResNet with the full 1000-class label set, (2) it correctly classifies
    a simple test image, and (3) .predict() and .get_preds() produce the
    same scores for the same image.
    """
    model = models.resnet18

    # Test if the function loads an ImageNet model (ResNet) trainer.
    learn = model_to_learner(model(pretrained=True))
    assert len(learn.data.classes) == 1000  # check ImageNet classes
    assert isinstance(learn.model, models.ResNet)

    # Test if the model can predict a very simple image.
    IM_URL = (
        "https://cvbp-secondary.z19.web.core.windows.net/images/cvbp_cup.jpg"
    )
    imagefile = os.path.join(tmp, "cvbp_cup.jpg")
    urllib.request.urlretrieve(IM_URL, imagefile)

    category, ind, predict_output = learn.predict(
        open_image(imagefile, convert_mode="RGB")
    )
    assert learn.data.classes[ind] == str(category) == "coffee_mug"

    # Test if .predict() yields the same output as .get_preds().
    # Build a one-image databunch from the folder holding the test image.
    one_data = (
        ImageList.from_folder(tmp)
        .split_none()
        # cannot use label_empty because of fastai bug:
        # https://github.com/fastai/fastai/issues/1908
        .label_const()
        .transform(tfms=None, size=IMAGENET_IM_SIZE)
        .databunch(bs=1)
        .normalize(imagenet_stats)
    )
    learn.data.train_dl = one_data.train_dl
    get_preds_output = learn.get_preds(ds_type=DatasetType.Train)

    # get_preds() produces a batch (list) output; compare its single item
    # against the scores from .predict(). np.allclose is the idiomatic
    # equivalent of np.all(np.isclose(...)) used previously.
    assert np.allclose(
        np.array(get_preds_output[0].tolist()[0]),
        np.array(predict_output.tolist()),
        rtol=1e-05,
        atol=1e-08,
    )
def test_train_metrics_recorder(tiny_ic_databunch):
    """Check TrainMetricsRecorder across three configurations.

    It must record one row of train/valid metrics per epoch with multiple
    metrics, record nothing when no metrics are set, and record only train
    metrics when the validation set is removed.
    """
    model = models.resnet18
    lr = 1e-4
    epochs = 2

    def fit_and_record(learn):
        # Attach a fresh recorder, train, and return it for inspection.
        # (Deliberately NOT named test_* so pytest never collects it.)
        tmr = TrainMetricsRecorder(learn)
        learn.callbacks.append(tmr)
        learn.fit(epochs, lr)
        return tmr

    # Multiple metrics: one row per epoch, one entry per metric.
    learn = cnn_learner(
        tiny_ic_databunch, model, metrics=[accuracy, error_rate]
    )
    cb = fit_and_record(learn)
    assert len(cb.train_metrics) == len(cb.valid_metrics) == epochs
    assert (
        len(cb.train_metrics[0]) == len(cb.valid_metrics[0]) == 2
    )  # we used 2 metrics

    # No metrics: nothing should be recorded.
    learn = cnn_learner(tiny_ic_databunch, model)
    cb = fit_and_record(learn)
    assert len(cb.train_metrics) == len(cb.valid_metrics) == 0  # no metrics

    # No validation set: only train metrics are recorded.
    learn = cnn_learner(tiny_ic_databunch, model, metrics=accuracy)
    valid_dl = learn.data.valid_dl
    learn.data.valid_dl = None
    try:
        cb = fit_and_record(learn)
        assert len(cb.train_metrics) == epochs
        assert len(cb.train_metrics[0]) == 1  # we used 1 metric
        assert len(cb.valid_metrics) == 0  # no validation
    finally:
        # tiny_ic_databunch is shared with other tests, so the validation
        # loader must be restored even if training or an assertion fails.
        learn.data.valid_dl = valid_dl
def test_get_preds(model_pred_scores):
    """Check that get_preds covers the whole validation set and matches
    the reference scores from learn.get_preds()."""
    learn, expected_scores = model_pred_scores
    predictions = get_preds(learn, learn.data.valid_dl)

    # One prediction per validation sample.
    assert len(predictions[0]) == len(learn.data.valid_ds)
    # Scores must be identical to the reference output.
    assert predictions[0].tolist() == expected_scores