From f258b11bc1458c327dba2b875b92efb0a77fa065 Mon Sep 17 00:00:00 2001 From: Advitya Gemawat Date: Fri, 25 Aug 2023 22:56:24 -0400 Subject: [PATCH] OD model and data util functions in rai_test_utils (#2246) * OD model & data ckpt * torch dependency * removed 3.6 from PR gate * removed 3.6 from PR gate * import fixes * str format fix * image utils * model util function * dependency updates * added constants * unit tests * lint fixes * sorted imports * dependency update * install updates * gate tweak * gate tweak * gate tweak * gate tweak * gate tweak * gate tweak * torch install tweak * torch install tweak * torch install tweak * torch install tweak * torch install tweak * torch install tweak * updated gate test for torch installation on pip * updated gate test for torch installation on pip * updated gate test for torch installation on pip * updated gate test for torch installation on pip * removed torch installation * torch ckpt * torch conda install * conda setup * gate tweak * bash test * gate py + conda * bash support to py commands * bash support to py commands * bash support to py commands * bash support to py commands * bash support to py commands * added mkl support * added mkl support * disabled py 3.7 due to pip internal import error * raiutils dependency update * raiutils dependency update * numpy+mkl fixes * numpy+mkl fixes * test revert * gate update * revert dependencies * removed pip compile * removed pip sync * convert CI-python to use conda and install pytorch dependencies * convert CI-python to use conda and install pytorch dependencies * dependency update * auto lint fixes * added torch & torchvision versions for macos compat * version revert * disabled torch incompat test --------- Co-authored-by: Ilya Matiach --- .github/workflows/CI-python.yml | 9 + .../datasets/vision/__init__.py | 12 + .../vision/object_detection_data_utils.py | 214 ++++++++++++++++++ .../rai_test_utils/models/__init__.py | 7 +- .../rai_test_utils/models/model_utils.py | 12 + .../rai_test_utils/models/torch/__init__.py | 8 + .../models/torch/torch_model_utils.py | 80 +++++++ rai_test_utils/requirements-dev.txt | 3 +- .../requirements-object-detection.txt | 3 + rai_test_utils/setup.py | 7 + rai_test_utils/tests/test_data_utils.py | 18 ++ rai_test_utils/tests/test_model_utils.py | 20 ++ 12 files changed, 390 insertions(+), 3 deletions(-) create mode 100644 rai_test_utils/rai_test_utils/datasets/vision/__init__.py create mode 100644 rai_test_utils/rai_test_utils/datasets/vision/object_detection_data_utils.py create mode 100644 rai_test_utils/rai_test_utils/models/torch/__init__.py create mode 100644 rai_test_utils/rai_test_utils/models/torch/torch_model_utils.py create mode 100644 rai_test_utils/requirements-object-detection.txt diff --git a/.github/workflows/CI-python.yml b/.github/workflows/CI-python.yml index 0c99f53ba..8f0fa8594 100644 --- a/.github/workflows/CI-python.yml +++ b/.github/workflows/CI-python.yml @@ -27,6 +27,8 @@ jobs: pythonVersion: "3.7" - operatingSystem: windows-latest pythonVersion: "3.7" + - operatingSystem: macos-latest + pythonVersion: "3.8" runs-on: ${{ matrix.operatingSystem }} @@ -74,6 +76,13 @@ jobs: pip install -r requirements-dev.txt working-directory: ${{ matrix.packageDirectory }} + - if: ${{ matrix.packageDirectory == 'rai_test_utils' }} + name: Install package extras + shell: bash -l {0} + run: | + pip install -r requirements-object-detection.txt + working-directory: ${{ matrix.packageDirectory }} + - name: Install package shell: bash -l {0} run: | diff --git a/rai_test_utils/rai_test_utils/datasets/vision/__init__.py b/rai_test_utils/rai_test_utils/datasets/vision/__init__.py new file mode 100644 index 000000000..03b43f7a9 --- /dev/null +++ b/rai_test_utils/rai_test_utils/datasets/vision/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation +# Licensed under the MIT License. + +"""Namespace for vision datasets.""" + +from .object_detection_data_utils import (get_images, + load_fridge_object_detection_dataset) + +__all__ = [ + "get_images", + "load_fridge_object_detection_dataset" +] diff --git a/rai_test_utils/rai_test_utils/datasets/vision/object_detection_data_utils.py b/rai_test_utils/rai_test_utils/datasets/vision/object_detection_data_utils.py new file mode 100644 index 000000000..aab731f74 --- /dev/null +++ b/rai_test_utils/rai_test_utils/datasets/vision/object_detection_data_utils.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft Corporation +# Licensed under the MIT License. + +import os +import urllib.request as request_file +import xml.etree.ElementTree as ET +from io import BytesIO +from urllib.parse import urlparse +from zipfile import ZipFile + +import numpy as np +import pandas as pd +import requests +from PIL import Image +from requests.adapters import HTTPAdapter +from requests.packages.urllib3.util.retry import Retry + +# domain mapped session for reuse +_requests_sessions = {} + + +def _get_retry_session(url): + domain = urlparse(url.lower()).netloc + if domain in _requests_sessions: + return _requests_sessions[domain] + + session = requests.Session() + retries = Retry( + total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504] + ) + session.mount("http://", HTTPAdapter(max_retries=retries)) + session.mount("https://", HTTPAdapter(max_retries=retries)) + _requests_sessions[domain] = session + + return session + + +def get_image_from_path(image_path, image_mode): + """Get image from path. + + :param image_path: The path to the image. + :type image_path: str + :param image_mode: The mode to open the image in. + See pillow documentation for all modes: + https://pillow.readthedocs.io/en/stable/handbook/concepts.html + :type image_mode: str + :return: The image as a numpy array. + :rtype: numpy.ndarray + """ + image_open_pointer = image_path + if image_path.startswith("http://") or image_path.startswith("https://"): + response = _get_retry_session(image_path).get(image_path) + image_open_pointer = BytesIO(response.content) + with Image.open(image_open_pointer) as im: + if image_mode is not None: + im = im.convert(image_mode) + image_array = np.asarray(im) + return image_array + + +def convert_images(dataset, image_mode): + """Converts the images to the format required by the model. + + If the images are base64 encoded, they are decoded and converted to + numpy arrays. If the images are already numpy arrays, they are + returned as is. + + :param dataset: The dataset to convert. + :type dataset: numpy.ndarray + :param image_mode: The mode to open the image in. + See pillow documentation for all modes: + https://pillow.readthedocs.io/en/stable/handbook/concepts.html + :type image_mode: str + :return: The converted dataset. + :rtype: numpy.ndarray + """ + if len(dataset) > 0 and isinstance(dataset[0], str): + try: + dataset = np.array([get_image_from_path( + x, image_mode) for x in dataset]) + except ValueError: + # if images of different sizes, try to convert one by one + jagged = np.empty(len(dataset), dtype=object) + for i, x in enumerate(dataset): + jagged[i] = get_image_from_path(x, image_mode) + dataset = jagged + return dataset + + +def get_images(dataset, image_mode, transformations=None): + """Get the images from the dataset. + + If transformations are provided as a callable, the images + are transformed. If transformations are provided as a string, + the images are retrieved from that column name in the test dataset. + + :param dataset: The dataset to get the images from. + :type dataset: numpy.ndarray + :param image_mode: The mode to open the image in. + See pillow documentation for all modes: + https://pillow.readthedocs.io/en/stable/handbook/concepts.html + :type image_mode: str + :param transformations: The transformations to apply to the images. + :type transformations: torchvision.transforms + :return: The images. + :rtype: numpy.ndarray + """ + IMAGE = "image" + IMAGE_URL = "image_url" + + column_names = dataset.columns + is_transformations_str = isinstance(transformations, str) + if is_transformations_str: + images = dataset[transformations] + else: + if IMAGE in column_names: + images = dataset[IMAGE] + elif IMAGE_URL in column_names: + images = dataset[IMAGE_URL] + else: + raise ValueError('No image column found in test data') + + images = np.array(images.tolist()) + converted_images = convert_images(images, image_mode) + + if not is_transformations_str and transformations is not None: + converted_images = transformations(converted_images) + + return converted_images + + +def load_fridge_object_detection_dataset_labels(): + """Loads the labels for the fridge object detection dataset. + + return: list of labels + rtype: list + """ + + src_images = "./data/odFridgeObjects/" + + # Path to the annotations + annotations_folder = os.path.join(src_images, "annotations") + + labels = [] + label_dict = {'can': 1, 'carton': 2, 'milk_bottle': 3, 'water_bottle': 4} + + # Read each annotation + for _, filename in enumerate(os.listdir(annotations_folder)): + if filename.endswith(".xml"): + print("Parsing " + os.path.join(src_images, filename)) + + root = ET.parse( + os.path.join(annotations_folder, filename) + ).getroot() + + # use if needed + # width = int(root.find("size/width").text) + # height = int(root.find("size/height").text) + + image_labels = [] + for object in root.findall("object"): + name = object.find("name").text + xmin = object.find("bndbox/xmin").text + ymin = object.find("bndbox/ymin").text + xmax = object.find("bndbox/xmax").text + ymax = object.find("bndbox/ymax").text + isCrowd = int(object.find("difficult").text) + image_labels.append([ + label_dict[name], # label + float(xmin), # topX. To normalize, divide by width. + float(ymin), # topY. To normalize, divide by height. + float(xmax), # bottomX. To normalize, divide by width + float(ymax), # bottomY. To normalize, divide by height + int(isCrowd) + ]) + labels.append(image_labels) + + return labels + + +def load_fridge_object_detection_dataset(): + """Loads the fridge object detection dataset. + + return: pandas dataframe with image paths and labels + rtype: pd.DataFrame + """ + # create data folder if it doesnt exist. + os.makedirs("data", exist_ok=True) + + # download data + download_url = ("https://cvbp-secondary.z19.web.core.windows.net/" + + "datasets/object_detection/odFridgeObjects.zip") + data_file = "./odFridgeObjects.zip" + request_file.urlretrieve(download_url, filename=data_file) + + # extract files + with ZipFile(data_file, "r") as zip: + print("extracting files...") + zip.extractall(path="./data") + print("done") + # delete zip file + os.remove(data_file) + + labels = load_fridge_object_detection_dataset_labels() + + # get all file names into a pandas dataframe with the labels + data = pd.DataFrame(columns=["image", "label"]) + for i, file in enumerate(os.listdir("./data/odFridgeObjects/" + "images")): + image_path = "./data/odFridgeObjects/" + "images" + "/" + file + data = data.append({"image": image_path, + "label": labels[i]}, # folder + ignore_index=True) + + return data diff --git a/rai_test_utils/rai_test_utils/models/__init__.py b/rai_test_utils/rai_test_utils/models/__init__.py index ee24840a4..f196a2408 100644 --- a/rai_test_utils/rai_test_utils/models/__init__.py +++ b/rai_test_utils/rai_test_utils/models/__init__.py @@ -3,9 +3,12 @@ """Namespace for models.""" -from .model_utils import create_models_classification, create_models_regression +from .model_utils import (create_models_classification, + create_models_object_detection, + create_models_regression) __all__ = [ "create_models_regression", - "create_models_classification" + "create_models_classification", + "create_models_object_detection" ] diff --git a/rai_test_utils/rai_test_utils/models/model_utils.py b/rai_test_utils/rai_test_utils/models/model_utils.py index 13b96d3f6..06f788ef7 100644 --- a/rai_test_utils/rai_test_utils/models/model_utils.py +++ b/rai_test_utils/rai_test_utils/models/model_utils.py @@ -5,6 +5,7 @@ from rai_test_utils.models.lightgbm import create_lightgbm_classifier from rai_test_utils.models.sklearn import ( create_sklearn_logistic_regressor, create_sklearn_random_forest_classifier, create_sklearn_random_forest_regressor, create_sklearn_svm_classifier) +from rai_test_utils.models.torch import get_object_detection_fridge_model from rai_test_utils.models.xgboost import create_xgboost_classifier @@ -40,3 +41,14 @@ def create_models_regression(X_train, y_train): rf_model = create_sklearn_random_forest_regressor(X_train, y_train) return [rf_model] + + +def create_models_object_detection(): + """Create a list of models for object detection. + + :return: A list of models. + :rtype: list + """ + fridge_model = get_object_detection_fridge_model() + + return [fridge_model] diff --git a/rai_test_utils/rai_test_utils/models/torch/__init__.py b/rai_test_utils/rai_test_utils/models/torch/__init__.py new file mode 100644 index 000000000..a41a2bc42 --- /dev/null +++ b/rai_test_utils/rai_test_utils/models/torch/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation +# Licensed under the MIT License. + +"""Namespace for torch models.""" + +from .torch_model_utils import get_object_detection_fridge_model + +__all__ = ["get_object_detection_fridge_model"] diff --git a/rai_test_utils/rai_test_utils/models/torch/torch_model_utils.py b/rai_test_utils/rai_test_utils/models/torch/torch_model_utils.py new file mode 100644 index 000000000..15762cbf4 --- /dev/null +++ b/rai_test_utils/rai_test_utils/models/torch/torch_model_utils.py @@ -0,0 +1,80 @@ +# Copyright (c) Microsoft Corporation +# Licensed under the MIT License. + +import logging +import os +import urllib.request as request_file + +module_logger = logging.getLogger(__name__) +module_logger.setLevel(logging.INFO) + +try: + import torch + import torchvision + from torchvision.models.detection.faster_rcnn import FastRCNNPredictor +except ImportError: + module_logger.debug( + 'Could not import torch/torchvision, required for object detection.') + + +# download fine-tuned recycling model from url +def download_assets(filepath, force=False): + """Download assets from url if not already downloaded. + + :param filepath: Path to the file to download. + :type filepath: str + :param force: Whether to force download the file. Defaults to False. + :type force: bool, optional + :returns: Path to the downloaded file. + :rtype: str + """ + if force or not os.path.exists(filepath): + url = ("https://publictestdatasets.blob.core.windows.net" + + "/models/fastrcnn.pt") + request_file.urlretrieve(url, os.path.join(filepath)) + else: + print('Found' + filepath) + + return filepath + + +def get_instance_segmentation_model(num_classes): + """Get an instance segmentation model. + + :param num_classes: Number of classes. + :type num_classes: int + :returns: Instance segmentation model. + :rtype: torchvision.models.detection.faster_rcnn.FasterRCNN + """ + # load an instance segmentation model pre-trained on COCO + model = torchvision.models.detection.fasterrcnn_resnet50_fpn( + pretrained=True + ) + in_features = model.roi_heads.box_predictor.cls_score.in_features + # replace the pre-trained head with a new one + model.roi_heads.box_predictor = FastRCNNPredictor( + in_features, + num_classes + ) + return model + + +def get_object_detection_fridge_model(): + """Loads the fridge object detection model. + + :returns: The fridge object detection model. + :rtype: torchvision.models.detection.faster_rcnn.FasterRCNN + """ + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + num_classes = 5 + model = get_instance_segmentation_model(num_classes) + _ = download_assets('Recycling_finetuned_FastRCNN.pt') + model.load_state_dict( + torch.load('Recycling_finetuned_FastRCNN.pt', + map_location=device + ) + ) + + model.to(device) + return model diff --git a/rai_test_utils/requirements-dev.txt b/rai_test_utils/requirements-dev.txt index c35f2d855..8833b3495 100644 --- a/rai_test_utils/requirements-dev.txt +++ b/rai_test_utils/requirements-dev.txt @@ -4,4 +4,5 @@ pytest-mock==3.6.1 requirements-parser==0.2.0 -pandas>=0.25.1,<2.0.0 \ No newline at end of file +pandas>=0.25.1,<2.0.0 +ml-wrappers \ No newline at end of file diff --git a/rai_test_utils/requirements-object-detection.txt b/rai_test_utils/requirements-object-detection.txt new file mode 100644 index 000000000..1e1cc4513 --- /dev/null +++ b/rai_test_utils/requirements-object-detection.txt @@ -0,0 +1,3 @@ +requests +Pillow>=10.0.0; python_version>"3.7" # due to breaking changes in v10.0.0 (https://pillow.readthedocs.io/en/latest/releasenotes/10.0.0.html) +Pillow<10.0.0; python_version<="3.7" # Pillow v10.0.0 is only available starting with Python 3.8 \ No newline at end of file diff --git a/rai_test_utils/setup.py b/rai_test_utils/setup.py index 5feef7038..f0a7a96c1 100644 --- a/rai_test_utils/setup.py +++ b/rai_test_utils/setup.py @@ -15,6 +15,12 @@ with open("README.md", "r") as fh: with open('requirements.txt') as f: install_requires = [line.strip() for line in f] +# Use requirements-object-detection.txt to set the install_requires +with open('requirements-object-detection.txt') as f: + extras_require = { + 'object_detection': [line.strip() for line in f] + } + setuptools.setup( name=name, # noqa: F821 version=version, # noqa: F821 @@ -27,6 +33,7 @@ setuptools.setup( packages=setuptools.find_packages(), python_requires='>=3.6', install_requires=install_requires, + extras_require=extras_require, classifiers=[ "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", diff --git a/rai_test_utils/tests/test_data_utils.py b/rai_test_utils/tests/test_data_utils.py index c764f5098..621fc658e 100644 --- a/rai_test_utils/tests/test_data_utils.py +++ b/rai_test_utils/tests/test_data_utils.py @@ -9,6 +9,8 @@ from rai_test_utils.datasets.tabular import ( create_energy_data, create_housing_data, create_iris_data, create_msx_data, create_multiclass_classification_dataset, create_reviews_data, create_simple_titanic_data, create_timeseries_data, create_wine_data) +from rai_test_utils.datasets.vision import ( + get_images, load_fridge_object_detection_dataset) class TestDataUtils: @@ -141,3 +143,19 @@ class TestDataUtils: assert X_test is not None assert y_train is not None assert y_test is not None + + def test_create_fridge_data(self): + dataset = load_fridge_object_detection_dataset() + X_train = X_test = dataset[["image"]] + y_train = y_test = dataset[["label"]] + assert X_train is not None + assert X_test is not None + assert y_train is not None + assert y_test is not None + + def test_get_images(self): + fridge_dataset = load_fridge_object_detection_dataset().iloc[:2] + images = get_images(fridge_dataset, "RGB", None) + assert len(images) == 2 + assert images[0].shape == (666, 499, 3) + assert images[1].shape == (666, 499, 3) diff --git a/rai_test_utils/tests/test_model_utils.py b/rai_test_utils/tests/test_model_utils.py index 37011d09a..a1b215d8e 100644 --- a/rai_test_utils/tests/test_model_utils.py +++ b/rai_test_utils/tests/test_model_utils.py @@ -1,10 +1,16 @@ # Copyright (c) Microsoft Corporation # Licensed under the MIT License. +import numpy as np +from ml_wrappers import wrap_model + from rai_test_utils.datasets.tabular import (create_housing_data, create_iris_data, create_simple_titanic_data) +from rai_test_utils.datasets.vision import ( + get_images, load_fridge_object_detection_dataset) from rai_test_utils.models import (create_models_classification, + create_models_object_detection, create_models_regression) from rai_test_utils.models.sklearn import \ create_complex_classification_pipeline @@ -32,3 +38,17 @@ class TestModelUtils: pipeline = create_complex_classification_pipeline( X_train, y_train, num_feature_names, cat_feature_names) assert pipeline.predict(X_test) is not None + + def test_object_detection_models(self): + dataset = load_fridge_object_detection_dataset().iloc[:2] + + X_train = dataset[["image"]] + classes = np.array(['can', 'carton', 'milk_bottle', 'water_bottle']) + + model_list = create_models_object_detection() + for model in model_list: + dataset = get_images(X_train, "RGB", None) + wrapped_model = wrap_model( + model, dataset, "object_detection", + classes=classes) + assert wrapped_model.predict(dataset) is not None