OD model and data util functions in rai_test_utils (#2246)
* OD model & data ckpt * torch dependency * removed 3.6 from PR gate * removed 3.6 from PR gate * import fixes * str format fix * image utils * model util function * dependency updates * added constants * unit tests * lint fixes * sorted imports * dependency update * install updates * gate tweak * gate tweak * gate tweak * gate tweak * gate tweak * gate tweak * torch install tweak * torch install tweak * torch install tweak * torch install tweak * torch install tweak * torch install tweak * updated gate test for torch installation on pip * updated gate test for torch installation on pip * updated gate test for torch installation on pip * updated gate test for torch installation on pip * removed torch installation * torch ckpt * torch conda install * conda setup * gate tweak * bash test * gate py + conda * bash support to py commands * bash support to py commands * bash support to py commands * bash support to py commands * bash support to py commands * added mkl support * added mkl support * disabled py 3.7 due to pip internal import error * raiutils dependency update * raiutils dependency update * numpy+mkl fixes * numpy+mkl fixes * test revert * gate update * revert dependencies * removed pip compile * removed pip sync * convert CI-python to use conda and install pytorch dependencies * convert CI-python to use conda and install pytorch dependencies * dependency update * auto lint fixes * added torch & torchvision versions for macos compat * version revert * disabled torch incompat test --------- Co-authored-by: Ilya Matiach <ilmat@microsoft.com>
This commit is contained in:
Родитель
928890a909
Коммит
f258b11bc1
|
@ -27,6 +27,8 @@ jobs:
|
||||||
pythonVersion: "3.7"
|
pythonVersion: "3.7"
|
||||||
- operatingSystem: windows-latest
|
- operatingSystem: windows-latest
|
||||||
pythonVersion: "3.7"
|
pythonVersion: "3.7"
|
||||||
|
- operatingSystem: macos-latest
|
||||||
|
pythonVersion: "3.8"
|
||||||
|
|
||||||
runs-on: ${{ matrix.operatingSystem }}
|
runs-on: ${{ matrix.operatingSystem }}
|
||||||
|
|
||||||
|
@ -74,6 +76,13 @@ jobs:
|
||||||
pip install -r requirements-dev.txt
|
pip install -r requirements-dev.txt
|
||||||
working-directory: ${{ matrix.packageDirectory }}
|
working-directory: ${{ matrix.packageDirectory }}
|
||||||
|
|
||||||
|
- if: ${{ matrix.packageDirectory == 'rai_test_utils' }}
|
||||||
|
name: Install package extras
|
||||||
|
shell: bash -l {0}
|
||||||
|
run: |
|
||||||
|
pip install -r requirements-object-detection.txt
|
||||||
|
working-directory: ${{ matrix.packageDirectory }}
|
||||||
|
|
||||||
- name: Install package
|
- name: Install package
|
||||||
shell: bash -l {0}
|
shell: bash -l {0}
|
||||||
run: |
|
run: |
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
# Copyright (c) Microsoft Corporation
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
"""Namespace for vision datasets."""
|
||||||
|
|
||||||
|
from .object_detection_data_utils import (get_images,
|
||||||
|
load_fridge_object_detection_dataset)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_images",
|
||||||
|
"load_fridge_object_detection_dataset"
|
||||||
|
]
|
|
@ -0,0 +1,214 @@
|
||||||
|
# Copyright (c) Microsoft Corporation
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import urllib.request as request_file
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from io import BytesIO
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from zipfile import ZipFile
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import requests
|
||||||
|
from PIL import Image
|
||||||
|
from requests.adapters import HTTPAdapter
|
||||||
|
from requests.packages.urllib3.util.retry import Retry
|
||||||
|
|
||||||
|
# domain mapped session for reuse
|
||||||
|
_requests_sessions = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_retry_session(url):
|
||||||
|
domain = urlparse(url.lower()).netloc
|
||||||
|
if domain in _requests_sessions:
|
||||||
|
return _requests_sessions[domain]
|
||||||
|
|
||||||
|
session = requests.Session()
|
||||||
|
retries = Retry(
|
||||||
|
total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504]
|
||||||
|
)
|
||||||
|
session.mount("http://", HTTPAdapter(max_retries=retries))
|
||||||
|
session.mount("https://", HTTPAdapter(max_retries=retries))
|
||||||
|
_requests_sessions[domain] = session
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_from_path(image_path, image_mode):
|
||||||
|
"""Get image from path.
|
||||||
|
|
||||||
|
:param image_path: The path to the image.
|
||||||
|
:type image_path: str
|
||||||
|
:param image_mode: The mode to open the image in.
|
||||||
|
See pillow documentation for all modes:
|
||||||
|
https://pillow.readthedocs.io/en/stable/handbook/concepts.html
|
||||||
|
:type image_mode: str
|
||||||
|
:return: The image as a numpy array.
|
||||||
|
:rtype: numpy.ndarray
|
||||||
|
"""
|
||||||
|
image_open_pointer = image_path
|
||||||
|
if image_path.startswith("http://") or image_path.startswith("https://"):
|
||||||
|
response = _get_retry_session(image_path).get(image_path)
|
||||||
|
image_open_pointer = BytesIO(response.content)
|
||||||
|
with Image.open(image_open_pointer) as im:
|
||||||
|
if image_mode is not None:
|
||||||
|
im = im.convert(image_mode)
|
||||||
|
image_array = np.asarray(im)
|
||||||
|
return image_array
|
||||||
|
|
||||||
|
|
||||||
|
def convert_images(dataset, image_mode):
|
||||||
|
"""Converts the images to the format required by the model.
|
||||||
|
|
||||||
|
If the images are base64 encoded, they are decoded and converted to
|
||||||
|
numpy arrays. If the images are already numpy arrays, they are
|
||||||
|
returned as is.
|
||||||
|
|
||||||
|
:param dataset: The dataset to convert.
|
||||||
|
:type dataset: numpy.ndarray
|
||||||
|
:param image_mode: The mode to open the image in.
|
||||||
|
See pillow documentation for all modes:
|
||||||
|
https://pillow.readthedocs.io/en/stable/handbook/concepts.html
|
||||||
|
:type image_mode: str
|
||||||
|
:return: The converted dataset.
|
||||||
|
:rtype: numpy.ndarray
|
||||||
|
"""
|
||||||
|
if len(dataset) > 0 and isinstance(dataset[0], str):
|
||||||
|
try:
|
||||||
|
dataset = np.array([get_image_from_path(
|
||||||
|
x, image_mode) for x in dataset])
|
||||||
|
except ValueError:
|
||||||
|
# if images of different sizes, try to convert one by one
|
||||||
|
jagged = np.empty(len(dataset), dtype=object)
|
||||||
|
for i, x in enumerate(dataset):
|
||||||
|
jagged[i] = get_image_from_path(x, image_mode)
|
||||||
|
dataset = jagged
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def get_images(dataset, image_mode, transformations=None):
|
||||||
|
"""Get the images from the dataset.
|
||||||
|
|
||||||
|
If transformations are provided as a callable, the images
|
||||||
|
are transformed. If transformations are provided as a string,
|
||||||
|
the images are retrieved from that column name in the test dataset.
|
||||||
|
|
||||||
|
:param dataset: The dataset to get the images from.
|
||||||
|
:type dataset: numpy.ndarray
|
||||||
|
:param image_mode: The mode to open the image in.
|
||||||
|
See pillow documentation for all modes:
|
||||||
|
https://pillow.readthedocs.io/en/stable/handbook/concepts.html
|
||||||
|
:type image_mode: str
|
||||||
|
:param transformations: The transformations to apply to the images.
|
||||||
|
:type transformations: torchvision.transforms
|
||||||
|
:return: The images.
|
||||||
|
:rtype: numpy.ndarray
|
||||||
|
"""
|
||||||
|
IMAGE = "image"
|
||||||
|
IMAGE_URL = "image_url"
|
||||||
|
|
||||||
|
column_names = dataset.columns
|
||||||
|
is_transformations_str = isinstance(transformations, str)
|
||||||
|
if is_transformations_str:
|
||||||
|
images = dataset[transformations]
|
||||||
|
else:
|
||||||
|
if IMAGE in column_names:
|
||||||
|
images = dataset[IMAGE]
|
||||||
|
elif IMAGE_URL in column_names:
|
||||||
|
images = dataset[IMAGE_URL]
|
||||||
|
else:
|
||||||
|
raise ValueError('No image column found in test data')
|
||||||
|
|
||||||
|
images = np.array(images.tolist())
|
||||||
|
converted_images = convert_images(images, image_mode)
|
||||||
|
|
||||||
|
if not is_transformations_str and transformations is not None:
|
||||||
|
converted_images = transformations(converted_images)
|
||||||
|
|
||||||
|
return converted_images
|
||||||
|
|
||||||
|
|
||||||
|
def load_fridge_object_detection_dataset_labels():
|
||||||
|
"""Loads the labels for the fridge object detection dataset.
|
||||||
|
|
||||||
|
return: list of labels
|
||||||
|
rtype: list
|
||||||
|
"""
|
||||||
|
|
||||||
|
src_images = "./data/odFridgeObjects/"
|
||||||
|
|
||||||
|
# Path to the annotations
|
||||||
|
annotations_folder = os.path.join(src_images, "annotations")
|
||||||
|
|
||||||
|
labels = []
|
||||||
|
label_dict = {'can': 1, 'carton': 2, 'milk_bottle': 3, 'water_bottle': 4}
|
||||||
|
|
||||||
|
# Read each annotation
|
||||||
|
for _, filename in enumerate(os.listdir(annotations_folder)):
|
||||||
|
if filename.endswith(".xml"):
|
||||||
|
print("Parsing " + os.path.join(src_images, filename))
|
||||||
|
|
||||||
|
root = ET.parse(
|
||||||
|
os.path.join(annotations_folder, filename)
|
||||||
|
).getroot()
|
||||||
|
|
||||||
|
# use if needed
|
||||||
|
# width = int(root.find("size/width").text)
|
||||||
|
# height = int(root.find("size/height").text)
|
||||||
|
|
||||||
|
image_labels = []
|
||||||
|
for object in root.findall("object"):
|
||||||
|
name = object.find("name").text
|
||||||
|
xmin = object.find("bndbox/xmin").text
|
||||||
|
ymin = object.find("bndbox/ymin").text
|
||||||
|
xmax = object.find("bndbox/xmax").text
|
||||||
|
ymax = object.find("bndbox/ymax").text
|
||||||
|
isCrowd = int(object.find("difficult").text)
|
||||||
|
image_labels.append([
|
||||||
|
label_dict[name], # label
|
||||||
|
float(xmin), # topX. To normalize, divide by width.
|
||||||
|
float(ymin), # topY. To normalize, divide by height.
|
||||||
|
float(xmax), # bottomX. To normalize, divide by width
|
||||||
|
float(ymax), # bottomY. To normalize, divide by height
|
||||||
|
int(isCrowd)
|
||||||
|
])
|
||||||
|
labels.append(image_labels)
|
||||||
|
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
def load_fridge_object_detection_dataset():
|
||||||
|
"""Loads the fridge object detection dataset.
|
||||||
|
|
||||||
|
return: pandas dataframe with image paths and labels
|
||||||
|
rtype: pd.DataFrame
|
||||||
|
"""
|
||||||
|
# create data folder if it doesnt exist.
|
||||||
|
os.makedirs("data", exist_ok=True)
|
||||||
|
|
||||||
|
# download data
|
||||||
|
download_url = ("https://cvbp-secondary.z19.web.core.windows.net/" +
|
||||||
|
"datasets/object_detection/odFridgeObjects.zip")
|
||||||
|
data_file = "./odFridgeObjects.zip"
|
||||||
|
request_file.urlretrieve(download_url, filename=data_file)
|
||||||
|
|
||||||
|
# extract files
|
||||||
|
with ZipFile(data_file, "r") as zip:
|
||||||
|
print("extracting files...")
|
||||||
|
zip.extractall(path="./data")
|
||||||
|
print("done")
|
||||||
|
# delete zip file
|
||||||
|
os.remove(data_file)
|
||||||
|
|
||||||
|
labels = load_fridge_object_detection_dataset_labels()
|
||||||
|
|
||||||
|
# get all file names into a pandas dataframe with the labels
|
||||||
|
data = pd.DataFrame(columns=["image", "label"])
|
||||||
|
for i, file in enumerate(os.listdir("./data/odFridgeObjects/" + "images")):
|
||||||
|
image_path = "./data/odFridgeObjects/" + "images" + "/" + file
|
||||||
|
data = data.append({"image": image_path,
|
||||||
|
"label": labels[i]}, # folder
|
||||||
|
ignore_index=True)
|
||||||
|
|
||||||
|
return data
|
|
@ -3,9 +3,12 @@
|
||||||
|
|
||||||
"""Namespace for models."""
|
"""Namespace for models."""
|
||||||
|
|
||||||
from .model_utils import create_models_classification, create_models_regression
|
from .model_utils import (create_models_classification,
|
||||||
|
create_models_object_detection,
|
||||||
|
create_models_regression)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"create_models_regression",
|
"create_models_regression",
|
||||||
"create_models_classification"
|
"create_models_classification",
|
||||||
|
"create_models_object_detection"
|
||||||
]
|
]
|
||||||
|
|
|
@ -5,6 +5,7 @@ from rai_test_utils.models.lightgbm import create_lightgbm_classifier
|
||||||
from rai_test_utils.models.sklearn import (
|
from rai_test_utils.models.sklearn import (
|
||||||
create_sklearn_logistic_regressor, create_sklearn_random_forest_classifier,
|
create_sklearn_logistic_regressor, create_sklearn_random_forest_classifier,
|
||||||
create_sklearn_random_forest_regressor, create_sklearn_svm_classifier)
|
create_sklearn_random_forest_regressor, create_sklearn_svm_classifier)
|
||||||
|
from rai_test_utils.models.torch import get_object_detection_fridge_model
|
||||||
from rai_test_utils.models.xgboost import create_xgboost_classifier
|
from rai_test_utils.models.xgboost import create_xgboost_classifier
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,3 +41,14 @@ def create_models_regression(X_train, y_train):
|
||||||
rf_model = create_sklearn_random_forest_regressor(X_train, y_train)
|
rf_model = create_sklearn_random_forest_regressor(X_train, y_train)
|
||||||
|
|
||||||
return [rf_model]
|
return [rf_model]
|
||||||
|
|
||||||
|
|
||||||
|
def create_models_object_detection():
|
||||||
|
"""Create a list of models for object detection.
|
||||||
|
|
||||||
|
:return: A list of models.
|
||||||
|
:rtype: list
|
||||||
|
"""
|
||||||
|
fridge_model = get_object_detection_fridge_model()
|
||||||
|
|
||||||
|
return [fridge_model]
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
# Copyright (c) Microsoft Corporation
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
"""Namespace for torch models."""
|
||||||
|
|
||||||
|
from .torch_model_utils import get_object_detection_fridge_model
|
||||||
|
|
||||||
|
__all__ = ["get_object_detection_fridge_model"]
|
|
@ -0,0 +1,80 @@
|
||||||
|
# Copyright (c) Microsoft Corporation
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import urllib.request as request_file
|
||||||
|
|
||||||
|
module_logger = logging.getLogger(__name__)
|
||||||
|
module_logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||||
|
except ImportError:
|
||||||
|
module_logger.debug(
|
||||||
|
'Could not import torch/torchvision, required for object detection.')
|
||||||
|
|
||||||
|
|
||||||
|
# download fine-tuned recycling model from url
|
||||||
|
def download_assets(filepath, force=False):
|
||||||
|
"""Download assets from url if not already downloaded.
|
||||||
|
|
||||||
|
:param filepath: Path to the file to download.
|
||||||
|
:type filepath: str
|
||||||
|
:param force: Whether to force download the file. Defaults to False.
|
||||||
|
:type force: bool, optional
|
||||||
|
:returns: Path to the downloaded file.
|
||||||
|
:rtype: str
|
||||||
|
"""
|
||||||
|
if force or not os.path.exists(filepath):
|
||||||
|
url = ("https://publictestdatasets.blob.core.windows.net" +
|
||||||
|
"/models/fastrcnn.pt")
|
||||||
|
request_file.urlretrieve(url, os.path.join(filepath))
|
||||||
|
else:
|
||||||
|
print('Found' + filepath)
|
||||||
|
|
||||||
|
return filepath
|
||||||
|
|
||||||
|
|
||||||
|
def get_instance_segmentation_model(num_classes):
|
||||||
|
"""Get an instance segmentation model.
|
||||||
|
|
||||||
|
:param num_classes: Number of classes.
|
||||||
|
:type num_classes: int
|
||||||
|
:returns: Instance segmentation model.
|
||||||
|
:rtype: torchvision.models.detection.faster_rcnn.FasterRCNN
|
||||||
|
"""
|
||||||
|
# load an instance segmentation model pre-trained on COCO
|
||||||
|
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
|
||||||
|
pretrained=True
|
||||||
|
)
|
||||||
|
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
||||||
|
# replace the pre-trained head with a new one
|
||||||
|
model.roi_heads.box_predictor = FastRCNNPredictor(
|
||||||
|
in_features,
|
||||||
|
num_classes
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_object_detection_fridge_model():
|
||||||
|
"""Loads the fridge object detection model.
|
||||||
|
|
||||||
|
:returns: The fridge object detection model.
|
||||||
|
:rtype: torchvision.models.detection.faster_rcnn.FasterRCNN
|
||||||
|
"""
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
num_classes = 5
|
||||||
|
model = get_instance_segmentation_model(num_classes)
|
||||||
|
_ = download_assets('Recycling_finetuned_FastRCNN.pt')
|
||||||
|
model.load_state_dict(
|
||||||
|
torch.load('Recycling_finetuned_FastRCNN.pt',
|
||||||
|
map_location=device
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
return model
|
|
@ -4,4 +4,5 @@ pytest-mock==3.6.1
|
||||||
|
|
||||||
requirements-parser==0.2.0
|
requirements-parser==0.2.0
|
||||||
|
|
||||||
pandas>=0.25.1,<2.0.0
|
pandas>=0.25.1,<2.0.0
|
||||||
|
ml-wrappers
|
|
@ -0,0 +1,3 @@
|
||||||
|
requests
|
||||||
|
Pillow>=10.0.0; python_version>"3.7" # due to breaking changes in v10.0.0 (https://pillow.readthedocs.io/en/latest/releasenotes/10.0.0.html)
|
||||||
|
Pillow<10.0.0; python_version<="3.7" # Pillow v10.0.0 is only available starting with Python 3.8
|
|
@ -15,6 +15,12 @@ with open("README.md", "r") as fh:
|
||||||
with open('requirements.txt') as f:
|
with open('requirements.txt') as f:
|
||||||
install_requires = [line.strip() for line in f]
|
install_requires = [line.strip() for line in f]
|
||||||
|
|
||||||
|
# Use requirements-object-detection.txt to set the install_requires
|
||||||
|
with open('requirements-object-detection.txt') as f:
|
||||||
|
extras_require = {
|
||||||
|
'object_detection': [line.strip() for line in f]
|
||||||
|
}
|
||||||
|
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
name=name, # noqa: F821
|
name=name, # noqa: F821
|
||||||
version=version, # noqa: F821
|
version=version, # noqa: F821
|
||||||
|
@ -27,6 +33,7 @@ setuptools.setup(
|
||||||
packages=setuptools.find_packages(),
|
packages=setuptools.find_packages(),
|
||||||
python_requires='>=3.6',
|
python_requires='>=3.6',
|
||||||
install_requires=install_requires,
|
install_requires=install_requires,
|
||||||
|
extras_require=extras_require,
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Programming Language :: Python :: 3.6",
|
"Programming Language :: Python :: 3.6",
|
||||||
"Programming Language :: Python :: 3.7",
|
"Programming Language :: Python :: 3.7",
|
||||||
|
|
|
@ -9,6 +9,8 @@ from rai_test_utils.datasets.tabular import (
|
||||||
create_energy_data, create_housing_data, create_iris_data, create_msx_data,
|
create_energy_data, create_housing_data, create_iris_data, create_msx_data,
|
||||||
create_multiclass_classification_dataset, create_reviews_data,
|
create_multiclass_classification_dataset, create_reviews_data,
|
||||||
create_simple_titanic_data, create_timeseries_data, create_wine_data)
|
create_simple_titanic_data, create_timeseries_data, create_wine_data)
|
||||||
|
from rai_test_utils.datasets.vision import (
|
||||||
|
get_images, load_fridge_object_detection_dataset)
|
||||||
|
|
||||||
|
|
||||||
class TestDataUtils:
|
class TestDataUtils:
|
||||||
|
@ -141,3 +143,19 @@ class TestDataUtils:
|
||||||
assert X_test is not None
|
assert X_test is not None
|
||||||
assert y_train is not None
|
assert y_train is not None
|
||||||
assert y_test is not None
|
assert y_test is not None
|
||||||
|
|
||||||
|
def test_create_fridge_data(self):
|
||||||
|
dataset = load_fridge_object_detection_dataset()
|
||||||
|
X_train = X_test = dataset[["image"]]
|
||||||
|
y_train = y_test = dataset[["label"]]
|
||||||
|
assert X_train is not None
|
||||||
|
assert X_test is not None
|
||||||
|
assert y_train is not None
|
||||||
|
assert y_test is not None
|
||||||
|
|
||||||
|
def test_get_images(self):
|
||||||
|
fridge_dataset = load_fridge_object_detection_dataset().iloc[:2]
|
||||||
|
images = get_images(fridge_dataset, "RGB", None)
|
||||||
|
assert len(images) == 2
|
||||||
|
assert images[0].shape == (666, 499, 3)
|
||||||
|
assert images[1].shape == (666, 499, 3)
|
||||||
|
|
|
@ -1,10 +1,16 @@
|
||||||
# Copyright (c) Microsoft Corporation
|
# Copyright (c) Microsoft Corporation
|
||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from ml_wrappers import wrap_model
|
||||||
|
|
||||||
from rai_test_utils.datasets.tabular import (create_housing_data,
|
from rai_test_utils.datasets.tabular import (create_housing_data,
|
||||||
create_iris_data,
|
create_iris_data,
|
||||||
create_simple_titanic_data)
|
create_simple_titanic_data)
|
||||||
|
from rai_test_utils.datasets.vision import (
|
||||||
|
get_images, load_fridge_object_detection_dataset)
|
||||||
from rai_test_utils.models import (create_models_classification,
|
from rai_test_utils.models import (create_models_classification,
|
||||||
|
create_models_object_detection,
|
||||||
create_models_regression)
|
create_models_regression)
|
||||||
from rai_test_utils.models.sklearn import \
|
from rai_test_utils.models.sklearn import \
|
||||||
create_complex_classification_pipeline
|
create_complex_classification_pipeline
|
||||||
|
@ -32,3 +38,17 @@ class TestModelUtils:
|
||||||
pipeline = create_complex_classification_pipeline(
|
pipeline = create_complex_classification_pipeline(
|
||||||
X_train, y_train, num_feature_names, cat_feature_names)
|
X_train, y_train, num_feature_names, cat_feature_names)
|
||||||
assert pipeline.predict(X_test) is not None
|
assert pipeline.predict(X_test) is not None
|
||||||
|
|
||||||
|
def test_object_detection_models(self):
|
||||||
|
dataset = load_fridge_object_detection_dataset().iloc[:2]
|
||||||
|
|
||||||
|
X_train = dataset[["image"]]
|
||||||
|
classes = np.array(['can', 'carton', 'milk_bottle', 'water_bottle'])
|
||||||
|
|
||||||
|
model_list = create_models_object_detection()
|
||||||
|
for model in model_list:
|
||||||
|
dataset = get_images(X_train, "RGB", None)
|
||||||
|
wrapped_model = wrap_model(
|
||||||
|
model, dataset, "object_detection",
|
||||||
|
classes=classes)
|
||||||
|
assert wrapped_model.predict(dataset) is not None
|
||||||
|
|
Загрузка…
Ссылка в новой задаче