* add containers for onnx models

* add tvm_installed, initial work on topology

* add containers
add tvm backend to supported
add a few tests

* fix type error in TVM
tree_trav and perf_tree_trav now work

* Add TVM_MAX_FUSE_DEPTH option
Add BATCH_SIZE option
Tree trav generates indexes based on batch size (if available)
TVM takes the max fuse depth configuration if set
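As a usage sketch (not part of the commit), the two new options are meant to be passed through `extra_config`; the constant names below follow this commit and are assumed to be exposed via `hummingbird.ml.constants`:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

from hummingbird.ml import convert, constants  # `constants` location assumed

X = np.random.rand(1000, 28).astype(np.float32)
y = np.random.randint(2, size=1000)
model = RandomForestClassifier(n_estimators=10).fit(X, y)

# TVM compiles for fixed shapes, so a batch size is set explicitly; the
# fuse-depth cap overrides the default of 50 used by the converter.
hb_model = convert(
    model,
    "tvm",
    X,
    extra_config={constants.TVM_MAX_FUSE_DEPTH: 30, constants.BATCH_SIZE: 100},
)
print(hb_model.predict(X[:100]))
```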
Matteo Interlandi 2020-11-03 13:21:02 -08:00, committed by GitHub
Parent db65391556
Commit 409c09a937
35 changed files: 1597 additions and 320 deletions

.github/workflows/pythonapp.yml

@@ -27,32 +27,29 @@ jobs:
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    # PyTorch for Mac has different pip syntax wrt Win and Linux.
    # PyTorch stopped supporting Python 3.5 from version 1.6.
    # The following cases address the situations above.
    - name: Install pytorch 1.5.1 if python 3.5 (mac)
      if: ${{ matrix.python-version == '3.5' && matrix.os == 'macos-latest' }}
      run: pip install torch==1.5.1
    - name: Install pytorch 1.7.0 if python 3.5 (mac)
      if: ${{ matrix.python-version != '3.5' && matrix.os == 'macos-latest' }}
      run: pip install torch==1.7.0
    - name: Install pytorch 1.5.1+cpu if python 3.5 (not mac)
      if: ${{ matrix.python-version == '3.5' && matrix.os != 'macos-latest' }}
      run: pip install torch==1.5.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
    - name: Install pytorch 1.7.0+cpu if python > 3.5 (not mac)
      if: ${{ matrix.python-version != '3.5' && matrix.os != 'macos-latest' }}
      run: pip install torch==1.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
    - name: Install basic dependencies
      run: |
        python -m pip install --upgrade pip
        pip install .[tests] -f https://download.pytorch.org/whl/torch_stable.html
    - name: Run basic tests without extra
      run: pytest
    - name: Coverage on basic tests without extra
      run: coverage run -a -m pytest tests/test_no_extra_install.py
    - name: If mac, install libomp to facilitate lgbm install
      if: matrix.os == 'macOS-latest'
      run: |
@@ -67,17 +64,70 @@ jobs:
      run: |
        pip install .[extra,onnx,sparkml]
        pip install pandas
    - uses: actions/cache@v1
      # TVM takes forever, we try to cache it.
      if: ${{ matrix.python-version != '3.5' && matrix.os != 'windows-2019' }}
      id: cache
      env:
        CACHE_NUMBER: 1
      with:
        path: ../../../incubator-tvm
        key: ${{ runner.os }}-${{ env.CACHE_NUMBER }}-tvm-0.7
    # Getting TVM requires: 1) fetching TVM from GitHub, 2) getting LLVM, 3) cmake, 4) make, 5) installing the Python dependency.
    # 1 to 4 will be retrieved from the cache.
    # The pipeline only works for Unix systems. For Windows we would have to compile LLVM from source, which is a no-go.
    - name: Fetch and prepare TVM for compilation
      if: ${{ steps.cache.outputs.cache-hit != 'true' && matrix.python-version != '3.5' && matrix.os != 'windows-2019' }}
      run: |
        cd ~/
        git clone https://github.com/apache/incubator-tvm.git
        cd incubator-tvm
        git checkout tags/v0.7.0
        git submodule update --recursive --init
        cmake -E make_directory build
    - name: Get LLVM on Linux
      if: ${{ steps.cache.outputs.cache-hit != 'true' && matrix.python-version != '3.5' && matrix.os == 'ubuntu-latest' }}
      working-directory: ../../../incubator-tvm
      run: |
        wget https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/clang+llvm-10.0.0-x86_64-linux-gnu-ubuntu-18.04.tar.xz
        tar -xf clang+llvm-10.0.0-x86_64-linux-gnu-ubuntu-18.04.tar.xz && mv clang+llvm-10.0.0-x86_64-linux-gnu-ubuntu-18.04 llvm
    - name: Get LLVM on Mac
      if: ${{ steps.cache.outputs.cache-hit != 'true' && matrix.python-version != '3.5' && matrix.os == 'macos-latest' }}
      working-directory: ../../../incubator-tvm
      run: |
        wget https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/clang+llvm-10.0.0-x86_64-apple-darwin.tar.xz
        tar -xf clang+llvm-10.0.0-x86_64-apple-darwin.tar.xz && mv clang+llvm-10.0.0-x86_64-apple-darwin llvm
    - name: CMake TVM
      if: ${{ steps.cache.outputs.cache-hit != 'true' && matrix.python-version != '3.5' && matrix.os != 'windows-2019' }}
      working-directory: ../../../incubator-tvm/build
      run: >-
        cmake
        "-DUSE_RPC=ON"
        "-DUSE_GRAPH_RUNTIME=ON"
        "-DUSE_LLVM=../llvm/bin/llvm-config"
        ..
    - name: Build TVM
      if: ${{ steps.cache.outputs.cache-hit != 'true' && matrix.python-version != '3.5' && matrix.os != 'windows-2019' }}
      working-directory: ../../../incubator-tvm/build
      run: |
        make -j3
    - name: Install python TVM
      if: ${{ matrix.python-version != '3.5' && matrix.os != 'windows-2019' }}
      working-directory: ../../../incubator-tvm/python
      run: |
        python setup.py install
    - name: Lint with flake8
      run: |
        # stop the build if there are Python syntax errors or undefined names
        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
        # The GitHub editor is 127 chars wide
        flake8 . --count --max-complexity=10 --max-line-length=127 --statistics
    # We don't run pytest for Linux py3.7 since we do coverage for that case.
    - name: Test with pytest
      if: ${{ matrix.python-version == '3.7' && matrix.os != 'ubuntu-latest' }}
      run: pytest
    # Run and push coverage only for one of the runs (Linux py3.7).
    - name: Coverage
      if: ${{ matrix.python-version == '3.7' && matrix.os == 'ubuntu-latest' }}
      run: |
        coverage run -a -m pytest tests
@@ -88,7 +138,8 @@ jobs:
      with:
        file: ./coverage.xml
        flags: unittests
    # Compile and push documentation only for one of the runs (Linux py3.7).
    - name: Generate Documentation
      if: ${{ matrix.python-version == '3.7' && matrix.os == 'ubuntu-latest' }}
      run: |
        make sphinx-site -C website/


@@ -16,8 +16,8 @@
## Introduction
*Hummingbird* is a library for compiling trained traditional ML models into tensor computations. *Hummingbird* allows users to seamlessly leverage neural network frameworks (such as [PyTorch](https://pytorch.org/)) to accelerate traditional ML models. Thanks to *Hummingbird*, users can benefit from: (1) all the current and future optimizations implemented in neural network frameworks; (2) native hardware acceleration; (3) having a unique platform to support both traditional and neural network models; and have all of this (4) without having to re-engineer their models.
Currently, you can use *Hummingbird* to convert your trained traditional ML models into [PyTorch](https://pytorch.org/), [TorchScript](https://pytorch.org/docs/stable/jit.html), [ONNX](https://onnx.ai/), and [TVM](https://docs.tvm.ai/). *Hummingbird* [supports](https://github.com/microsoft/hummingbird/wiki/Supported-Operators) a variety of ML models and featurizers. These models include
[scikit-learn](https://scikit-learn.org/stable/) Decision Trees and Random Forest, and also [LightGBM](https://github.com/Microsoft/LightGBM) and [XGBoost](https://github.com/dmlc/xgboost) Classifiers/Regressors. Support for other neural network backends and models is on our [roadmap](https://github.com/microsoft/hummingbird/wiki/Roadmap-for-Upcoming-Features-and-Support).
Hummingbird also provides a convenient uniform "inference" API following the Sklearn API. This allows swapping Sklearn models with Hummingbird-generated ones without having to change the inference code.
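A minimal sketch of converting a model to the new TVM backend (illustrative data; assumes TVM is installed):

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from hummingbird.ml import convert

X = np.random.rand(10000, 28).astype(np.float32)
y = np.random.randint(2, size=10000)
skl_model = RandomForestClassifier(n_estimators=10).fit(X, y)

# TVM requires a test input so the model can be traced and compiled.
hb_model = convert(skl_model, "tvm", X)
print(hb_model.predict(X[:10]))
```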


@@ -63,11 +63,7 @@ import benchmarks.operators.train as train
import benchmarks.operators.score as score

from benchmarks.datasets import prepare_dataset, LearningTask
from hummingbird.ml._utils import sklearn_installed, onnx_ml_tools_installed, onnx_runtime_installed, tvm_installed

ROOT_PATH = Path(__file__).absolute().parent.parent.parent
@@ -106,7 +102,6 @@ def get_number_processors(args):

def print_sys_info(args):
    import sklearn
    import torch
@@ -114,7 +109,20 @@ def print_sys_info(args):
    print("OS : %s" % sys.platform)
    print("Sklearn : %s" % sklearn.__version__)
    print("PyTorch : %s" % torch.__version__)

    # Optional imports
    try:
        import onnxruntime

        print("ORT : %s" % onnxruntime.__version__)
    except ImportError:
        pass
    try:
        import tvm

        print("TVM : %s" % tvm.__version__)
    except ImportError:
        pass

    if args.gpu:
        print("Running on GPU")
@@ -229,7 +237,19 @@ def benchmark(args, dataset_folder, model_folder, dataset):
            args.operator = op
    if args.backend == "all":
        args.backend = "onnx-ml,hb-pytorch,hb-torchscript,hb-onnx"
    if "hb-tvm" in args.backend:
        assert (
            tvm_installed()
        ), "To run the benchmark with TVM you need to have TVM installed. Either install TVM or remove it from the backends."
    if "hb-onnx" in args.backend:
        assert (
            onnx_runtime_installed()
        ), "To run the benchmark with ONNX you need to have the ONNX runtime installed. Either install the ONNX runtime or remove ONNX from the backends."
    if "onnx-ml" in args.backend:
        assert (
            onnx_runtime_installed() and onnx_ml_tools_installed()
        ), "To run the benchmark with ONNX-ML you need to have the ONNX runtime and ONNXMLTOOLS installed. Either install them or remove ONNX-ML from the backends."
    for backend in args.backend.split(","):
        print("Running '%s' ..." % backend)
        scorer = score.ScoreBackend.create(backend)


@@ -44,6 +44,8 @@ class ScoreBackend(ABC):
            return HBBackend("torch")
        if name == "hb-torchscript":
            return HBBackend("torch.jit")
        if name == "hb-tvm":
            return HBBackend("tvm")
        if name == "hb-onnx":
            return HBBackend("onnx")
        if name == "onnx-ml":
@@ -118,7 +120,7 @@ class ScoreBackend(ABC):
class HBBackend(ScoreBackend):
    def __init__(self, backend):
        super(HBBackend, self).__init__()
        self.backend = backend

    def convert(self, model, data, args, model_name):


@@ -32,6 +32,21 @@ def print_sys_info(args):
    print("OS : %s" % sys.platform)
    print("Sklearn: %s" % sklearn.__version__)
    print("Torch: %s" % torch.__version__)

    # Optional imports
    try:
        import onnxruntime

        print("ORT : %s" % onnxruntime.__version__)
    except ImportError:
        pass
    try:
        import tvm

        print("TVM : %s" % tvm.__version__)
    except ImportError:
        pass

    if args.gpu:
        print("Running on GPU")
    else:


@@ -15,7 +15,7 @@ import hummingbird.ml

class ScoreBackend(ABC):
    @staticmethod
    def create(name):
        if name in ["torch", "torch.jit", "tvm", "onnx"]:
            return HBBackend(name)

        raise ValueError("Unknown backend: " + name)


@@ -56,6 +56,7 @@ from hummingbird.ml._utils import (
    sklearn_installed,
    onnx_ml_tools_installed,
    onnx_runtime_installed,
    tvm_installed,
)
@@ -88,12 +89,19 @@ def print_sys_info(args):
    print("Sklearn : %s" % sklearn.__version__)
    print("PyTorch : %s" % torch.__version__)

    # Optional imports
    try:
        import onnxruntime

        print("ORT : %s" % onnxruntime.__version__)
    except ImportError:
        pass
    try:
        import tvm

        print("TVM : %s" % tvm.__version__)
    except ImportError:
        pass

    if args.gpu:
        print("Running on GPU")
@@ -235,7 +243,19 @@ def benchmark(args, dataset_folder, model_folder, dataset):
            args.operator = op
    if args.backend == "all":
        args.backend = "onnx-ml,hb-pytorch,hb-torchscript,hb-onnx,hb-tvm"
    if "hb-tvm" in args.backend:
        assert (
            tvm_installed()
        ), "To run the benchmark with TVM you need to have TVM installed. Either install TVM or remove it from the backends."
    if "hb-onnx" in args.backend:
        assert (
            onnx_runtime_installed()
        ), "To run the benchmark with ONNX you need to have the ONNX runtime installed. Either install the ONNX runtime or remove ONNX from the backends."
    if "onnx-ml" in args.backend:
        assert (
            onnx_runtime_installed() and onnx_ml_tools_installed()
        ), "To run the benchmark with ONNX-ML you need to have the ONNX runtime and ONNXMLTOOLS installed. Either install them or remove ONNX-ML from the backends."
    for backend in args.backend.split(","):
        print("Running '%s' ..." % backend)
        scorer = score.ScoreBackend.create(backend)
@@ -302,6 +322,5 @@ if __name__ == "__main__":
    assert xgboost_installed(), "benchmark requires XGBoost"
    assert lightgbm_installed(), "benchmark requires LightGBM"
    assert sklearn_installed(), "benchmark requires sklearn"

    main()


@@ -41,6 +41,8 @@ class ScoreBackend(ABC):
            return HBBackend("torch")
        if name == "hb-torchscript":
            return HBBackend("torch.jit")
        if name == "hb-tvm":
            return HBBackend("tvm")
        if name == "hb-onnx":
            return HBBackend("onnx")
        if name == "onnx-ml":
@@ -94,7 +96,7 @@ class ScoreBackend(ABC):
class HBBackend(ScoreBackend):
    def __init__(self, backend):
        super(HBBackend, self).__init__()
        self.backend = backend

    def convert(self, model, data, args, model_name):


@@ -11,13 +11,14 @@ In Hummingbird we use two types of containers:
- containers for output models (e.g., `SklearnContainer`) used to surface output models as a unified API format.
"""

from abc import ABC, abstractmethod
import os

import numpy as np
from onnxconverter_common.container import CommonSklearnModelContainer
import torch

from hummingbird.ml.operator_converters import constants
from hummingbird.ml._utils import onnx_runtime_installed, tvm_installed, pandas_installed, _get_device

if pandas_installed():
    from pandas import DataFrame
@@ -44,7 +45,8 @@ class CommonSparkMLModelContainer(CommonSklearnModelContainer):
        super(CommonSparkMLModelContainer, self).__init__(sparkml_model)


# Output containers.
# Abstract containers enabling the Sklearn API.
class SklearnContainer(ABC):
    def __init__(self, model, n_threads=None, batch_size=None, extra_config={}):
        """
@@ -62,6 +64,7 @@ class SklearnContainer(ABC):
        self._n_threads = n_threads
        self._batch_size = batch_size
        self._extra_config = extra_config
        self._last_iteration = False

    @property
    def model(self):
@@ -106,6 +109,9 @@ class SklearnContainer(ABC):
                batch = tuple([input[start:end, :] for input in inputs])
            else:
                batch = inputs[start:end, :]

            # Tell the function that we are in the last iteration so that it can take
            # the proper actions (e.g., for TVM we may want to use the remainder model).
            self._last_iteration = i == iterations - 1
            predictions.extend(function(*batch).ravel())

        if reshape:
@@ -113,20 +119,17 @@ class SklearnContainer(ABC):
        return np.array(predictions).ravel()


class SklearnContainerTransformer(SklearnContainer):
    """
    Abstract container mirroring Sklearn transformers API.
    """

    @abstractmethod
    def _transform(self, *input):
        """
        This method contains the container-specific implementation of transform.
        """
        pass

    def transform(self, *inputs):
        """
@@ -136,28 +139,27 @@ class PyTorchSklearnContainerTransformer(PyTorchTorchscriptSklearnContainer):
        return self._run(self._transform, *inputs, reshape=True)


class SklearnContainerRegression(SklearnContainer):
    """
    Abstract container mirroring Sklearn regressors API.
    """

    def __init__(
        self, model, n_threads, batch_size, is_regression=True, is_anomaly_detection=False, extra_config={}, **kwargs
    ):
        super(SklearnContainerRegression, self).__init__(model, n_threads, batch_size, extra_config)

        assert not (is_regression and is_anomaly_detection)

        self._is_regression = is_regression
        self._is_anomaly_detection = is_anomaly_detection

    @abstractmethod
    def _predict(self, *input):
        """
        This method contains the container-specific implementation of predict.
        """
        pass

    def predict(self, *inputs):
        """
@@ -169,18 +171,22 @@ class PyTorchSklearnContainerRegression(PyTorchTorchscriptSklearnContainer):
        return self._run(self._predict, *inputs)


class SklearnContainerClassification(SklearnContainerRegression):
    """
    Container mirroring Sklearn classifiers API.
    """

    def __init__(self, model, n_threads, batch_size, extra_config={}):
        super(SklearnContainerClassification, self).__init__(
            model, n_threads, batch_size, is_regression=False, extra_config=extra_config
        )

    @abstractmethod
    def _predict_proba(self, *input):
        """
        This method contains the container-specific implementation of predict_proba.
        """
        pass

    def predict_proba(self, *inputs):
        """
@@ -190,18 +196,22 @@ class PyTorchSklearnContainerClassification(PyTorchSklearnContainerRegression):
        return self._run(self._predict_proba, *inputs, reshape=True)


class SklearnContainerAnomalyDetection(SklearnContainerRegression):
    """
    Container mirroring Sklearn anomaly detection API.
    """

    def __init__(self, model, n_threads, batch_size, extra_config={}):
        super(SklearnContainerAnomalyDetection, self).__init__(
            model, n_threads, batch_size, is_regression=False, is_anomaly_detection=True, extra_config=extra_config
        )

    @abstractmethod
    def _decision_function(self, *inputs):
        """
        This method contains the container-specific implementation of decision_function.
        """
        pass

    def decision_function(self, *inputs):
        """
@@ -223,6 +233,48 @@ class PyTorchSklearnContainerAnomalyDetection(PyTorchSklearnContainerRegression)
        return self.decision_function(*inputs) + self._extra_config[constants.OFFSET]


# PyTorch containers.
class PyTorchSklearnContainerTransformer(SklearnContainerTransformer):
    """
    Container for PyTorch models mirroring Sklearn transformers API.
    """

    def _transform(self, *inputs):
        return self.model.forward(*inputs).cpu().numpy()


class PyTorchSklearnContainerRegression(SklearnContainerRegression):
    """
    Container for PyTorch models mirroring Sklearn regressors API.
    """

    def _predict(self, *inputs):
        if self._is_regression:
            return self.model.forward(*inputs).cpu().numpy().ravel()
        elif self._is_anomaly_detection:
            return self.model.forward(*inputs)[0].cpu().numpy().ravel()
        else:
            return self.model.forward(*inputs)[0].cpu().numpy().ravel()


class PyTorchSklearnContainerClassification(PyTorchSklearnContainerRegression, SklearnContainerClassification):
    """
    Container for PyTorch models mirroring Sklearn classifiers API.
    """

    def _predict_proba(self, *input):
        return self.model.forward(*input)[1].cpu().numpy()


class PyTorchSklearnContainerAnomalyDetection(PyTorchSklearnContainerRegression, SklearnContainerAnomalyDetection):
    """
    Container for PyTorch models mirroring the Sklearn anomaly detection API.
    """

    def _decision_function(self, *inputs):
        return self.model.forward(*inputs)[1].cpu().numpy().ravel()
# TorchScript containers.
def _torchscript_wrapper(device, function, *inputs):
    """
@@ -232,6 +284,14 @@ def _torchscript_wrapper(device, function, *inputs):
    inputs = [*inputs]

    with torch.no_grad():
        if type(inputs) == DataFrame and DataFrame is not None:
            # Split the dataframe into column ndarrays.
            inputs = inputs[0]
            input_names = list(inputs.columns)
            splits = [inputs[input_names[idx]] for idx in range(len(input_names))]
            splits = [df.to_numpy().reshape(-1, 1) for df in splits]
            inputs = tuple(splits)

        # Maps data inputs to the expected type and device.
        for i in range(len(inputs)):
            if type(inputs[i]) is np.ndarray:
@@ -245,7 +305,7 @@ def _torchscript_wrapper(device, function, *inputs):
class TorchScriptSklearnContainerTransformer(PyTorchSklearnContainerTransformer):
    """
    Container for TorchScript models mirroring Sklearn transformers API.
    """

    def transform(self, *inputs):
@@ -258,7 +318,7 @@ class TorchScriptSklearnContainerTransformer(PyTorchSklearnContainerTransformer)
class TorchScriptSklearnContainerRegression(PyTorchSklearnContainerRegression):
    """
    Container for TorchScript models mirroring Sklearn regressors API.
    """

    def predict(self, *inputs):
@@ -271,7 +331,7 @@ class TorchScriptSklearnContainerRegression(PyTorchSklearnContainerRegression):
class TorchScriptSklearnContainerClassification(PyTorchSklearnContainerClassification):
    """
    Container for TorchScript models mirroring Sklearn classifiers API.
    """

    def predict(self, *inputs):
@@ -291,7 +351,7 @@ class TorchScriptSklearnContainerClassification(PyTorchSklearnContainerClassific
class TorchScriptSklearnContainerAnomalyDetection(PyTorchSklearnContainerAnomalyDetection):
    """
    Container for TorchScript models mirroring Sklearn anomaly detection API.
    """

    def predict(self, *inputs):
@@ -306,7 +366,11 @@ class TorchScriptSklearnContainerAnomalyDetection(PyTorchSklearnContainerAnomaly
        f = super(TorchScriptSklearnContainerAnomalyDetection, self)._decision_function
        f_wrapped = lambda x: _torchscript_wrapper(device, f, x)  # noqa: E731

        scores = self._run(f_wrapped, *inputs)
        if constants.IFOREST_THRESHOLD in self._extra_config:
            scores += self._extra_config[constants.IFOREST_THRESHOLD]
        return scores

    def score_samples(self, *inputs):
        device = _get_device(self.model)
@@ -329,9 +393,6 @@ class ONNXSklearnContainer(SklearnContainer):
        if onnx_runtime_installed():
            import onnxruntime as ort

            sess_options = ort.SessionOptions()
            if self._n_threads is not None:
                sess_options.intra_op_num_threads = self._n_threads
@@ -339,7 +400,7 @@ class ONNXSklearnContainer(SklearnContainer):
                sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
            self._session = ort.InferenceSession(self._model.SerializeToString(), sess_options=sess_options)
            self._output_names = [self._session.get_outputs()[i].name for i in range(len(self._session.get_outputs()))]
            self._input_names = [input.name for input in self._session.get_inputs()]
        else:
            raise RuntimeError("ONNX Container requires ONNX runtime installed.")
@@ -347,147 +408,154 @@ class ONNXSklearnContainer(SklearnContainer):
        """
        Retrieve the inputs names from the session object.
        """
        if len(inputs) < len(self._input_names):
            inputs = inputs[0]

        assert len(inputs) == len(self._input_names)

        named_inputs = {}

        for i in range(len(inputs)):
            named_inputs[self._input_names[i]] = np.array(inputs[i])

        return named_inputs


class ONNXSklearnContainerTransformer(ONNXSklearnContainer, SklearnContainerTransformer):
    """
    Container for ONNX models mirroring Sklearn transformers API.
    """

    def _transform(self, *inputs):
        assert len(self._output_names) == 1

        named_inputs = self._get_named_inputs(inputs)

        return np.array(self._session.run(self._output_names, named_inputs))


class ONNXSklearnContainerRegression(ONNXSklearnContainer, SklearnContainerRegression):
    """
    Container for ONNX models mirroring Sklearn regressors API.
    """

    def _predict(self, *inputs):
        named_inputs = self._get_named_inputs(inputs)

        if self._is_regression:
            assert len(self._output_names) == 1
            return np.array(self._session.run(self._output_names, named_inputs))
        elif self._is_anomaly_detection:
            assert len(self._output_names) == 2
            return np.array(self._session.run([self._output_names[0]], named_inputs))[0].ravel()
        else:
            assert len(self._output_names) == 2
            return np.array(self._session.run([self._output_names[0]], named_inputs))[0]


class ONNXSklearnContainerClassification(ONNXSklearnContainerRegression, SklearnContainerClassification):
    """
    Container for ONNX models mirroring Sklearn classifiers API.
    """

    def _predict_proba(self, *inputs):
        assert len(self._output_names) == 2

        named_inputs = self._get_named_inputs(inputs)

        return self._session.run([self._output_names[1]], named_inputs)[0]


class ONNXSklearnContainerAnomalyDetection(ONNXSklearnContainerRegression, SklearnContainerAnomalyDetection):
    """
    Container for ONNX models mirroring Sklearn anomaly detection API.
    """

    def _decision_function(self, *inputs):
        assert len(self._output_names) == 2

        named_inputs = self._get_named_inputs(inputs)

        return np.array(self._session.run([self._output_names[1]], named_inputs)[0]).flatten()


# TVM containers.
class TVMSklearnContainer(SklearnContainer):
    """
    Base container for TVM models.
    The container allows us to mirror the Sklearn API.
    """

    def __init__(self, model, n_threads=None, batch_size=None, extra_config={}):
        super(TVMSklearnContainer, self).__init__(model, n_threads, batch_size, extra_config=extra_config)

        assert tvm_installed()
        import tvm

        self._ctx = self._extra_config[constants.TVM_CONTEXT]
        self._input_names = self._extra_config[constants.TVM_INPUT_NAMES]
        self._remainder_model = None
        if constants.TVM_REMAINDER_MODEL in self._extra_config:
            self._remainder_model = self._extra_config[constants.TVM_REMAINDER_MODEL]
        self._to_tvm_array = lambda x: tvm.nd.array(x, self._ctx)

        os.environ["TVM_NUM_THREADS"] = str(self._n_threads)

    def _to_tvm_tensor(self, *inputs):
        return {self._input_names[i]: self._to_tvm_array(inputs[i]) for i in range(len(inputs))}


class TVMSklearnContainerTransformer(TVMSklearnContainer, SklearnContainerTransformer):
    """
    Container for TVM models mirroring Sklearn transformers API.
    """

    def _transform(self, *inputs):
        if self._last_iteration and self._remainder_model is not None:
            self._remainder_model.run(**self._to_tvm_tensor(*inputs))
            return self._remainder_model.get_output(0).asnumpy()
        self.model.run(**self._to_tvm_tensor(*inputs))
        return self.model.get_output(0).asnumpy()


class TVMSklearnContainerRegression(TVMSklearnContainer, SklearnContainerRegression):
    """
    Container for TVM models mirroring Sklearn regressors API.
    """

    def _predict(self, *inputs):
        if self._last_iteration and self._remainder_model is not None:
            self._remainder_model.run(**self._to_tvm_tensor(*inputs))
            return self._remainder_model.get_output(0).asnumpy().ravel()
        self.model.run(**self._to_tvm_tensor(*inputs))
        return self.model.get_output(0).asnumpy().ravel()


class TVMSklearnContainerClassification(TVMSklearnContainerRegression, SklearnContainerClassification):
    """
    Container for TVM models mirroring Sklearn classifiers API.
    """

    def _predict_proba(self, *inputs):
        if self._last_iteration and self._remainder_model is not None:
            self._remainder_model.run(**self._to_tvm_tensor(*inputs))
            return self._remainder_model.get_output(1).asnumpy()
        self.model.run(**self._to_tvm_tensor(*inputs))
        return self.model.get_output(1).asnumpy()


class TVMSklearnContainerAnomalyDetection(TVMSklearnContainerRegression, SklearnContainerAnomalyDetection):
    """
    Container for TVM models mirroring Sklearn anomaly detection API.
    """

    def _decision_function(self, *inputs):
        if self._last_iteration and self._remainder_model is not None:
            self._remainder_model.run(**self._to_tvm_tensor(*inputs))
            return self._remainder_model.get_output(1).asnumpy().ravel()
        else:
            self.model.run(**self._to_tvm_tensor(*inputs))
            return self.model.get_output(1).asnumpy().ravel()
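The `_last_iteration`/remainder-model logic above exists because TVM compiles models for a fixed input shape. A sketch of the batching arithmetic (hypothetical helper, not part of the codebase):

```python
def split_batches(n_rows, batch_size):
    # A dataset whose size is not a multiple of batch_size needs one extra
    # model compiled for the remainder shape (mirrors the container logic above).
    return n_rows // batch_size, n_rows % batch_size

iterations, remainder = split_batches(1050, 100)
assert (iterations, remainder) == (10, 50)
# The batch model runs 10 times on 100-row inputs; the remainder model,
# compiled for a 50-row input, runs once on the final partial batch.
```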


@@ -15,6 +15,7 @@ from uuid import uuid4
from onnxconverter_common.registration import get_converter
import onnx
import timeit
from hummingbird.ml._container import (
    PyTorchSklearnContainerRegression,
@@ -29,8 +30,12 @@ from hummingbird.ml._container import (
    ONNXSklearnContainerClassification,
    ONNXSklearnContainerTransformer,
    ONNXSklearnContainerAnomalyDetection,
    TVMSklearnContainerRegression,
    TVMSklearnContainerClassification,
    TVMSklearnContainerTransformer,
    TVMSklearnContainerAnomalyDetection,
)
from hummingbird.ml._utils import pandas_installed, tvm_installed, _get_device
from hummingbird.ml.exceptions import MissingConverter
from hummingbird.ml.operator_converters import constants
@@ -40,20 +45,40 @@ else:
    DataFrame = None


def _jit_model(torch_model, trace_input, device, extra_config):
    """
    Function used to convert an input PyTorch model into torchscript.
    """
    if device != "cpu":
        trace_input.to(device)
    return torch.jit.trace(torch_model, trace_input).eval()


def _get_trace_input_from_test_input(input, batch_size):
    """
    Utility function used to properly put the inputs into a format understandable by torch.
    If a non-None batch_size is passed, the function generates a tracing input of size batch_size.
    If the input size modulo batch_size is not 0, a tracing input for the remainder is also generated.
    """
    remainder = None
    if type(input) is tuple:
        if batch_size is not None:
            trace_input = tuple([torch.from_numpy(i)[0:batch_size, :] for i in input])
            if len(input) > 0 and input[0].shape[0] % batch_size != 0:
                remainder_size = input[0].shape[0] % batch_size
                remainder = tuple([torch.from_numpy(i)[0:remainder_size, :] for i in input])
        else:
            trace_input = tuple([torch.from_numpy(i) for i in input])
    else:
        trace_input = torch.from_numpy(input)
        if batch_size is not None:
            batch_input = trace_input[0:batch_size, :]
            remainder_size = len(input) % batch_size
            if remainder_size != 0:
                remainder = trace_input[0:remainder_size, :]
            trace_input = batch_input
    return (trace_input, remainder)


def convert(topology, backend, device, extra_config={}):
@@ -73,8 +98,16 @@ def convert(topology, backend, device, extra_config={}):
    assert backend is not None, "Cannot convert a Topology object into backend None."
    assert device is not None, "Cannot convert a Topology object into device None."

    tvm_backend = None
    operator_map = {}

    if tvm_installed():
        import tvm
        from tvm import relay
        from tvm.contrib import graph_runtime

        tvm_backend = tvm.__name__

    for operator in topology.topological_operator_iterator():
        try:
            converter = get_converter(operator.type)
@@ -128,12 +161,12 @@ def convert(topology, backend, device, extra_config={}):
        output_model_name = str(uuid4().hex) + ".onnx"

        # Put the tracing test input into the right format.
        batch_trace_input, _ = _get_trace_input_from_test_input(extra_config[constants.TEST_INPUT], batch_size)

        # Generate the ONNX models
        torch.onnx.export(
            torch_model,
            batch_trace_input,
            output_model_name,
            input_names=topology.raw_model.input_names,
            output_names=topology.raw_model.output_names,
@@ -185,6 +218,77 @@ def convert(topology, backend, device, extra_config={}):
            return num_fixed

        fix_graph(hb_model.graph)
    elif backend == tvm_backend:
        # First we need to generate the torchscript model.
        batch_trace_input, remainder_trace_input = _get_trace_input_from_test_input(
            extra_config[constants.TEST_INPUT], batch_size
        )
        ts_model = _jit_model(torch_model, batch_trace_input, "cpu", extra_config)
        if remainder_trace_input is not None:
            remainder_ts_model = _jit_model(torch_model, remainder_trace_input, "cpu", extra_config)

        # Generate the test input in the TVM format. In case we have a remainder beyond the batch, generate a remainder test input as well.
        test_input = [
            (
                topology.raw_model.input_names[i],
                batch_trace_input[i].shape if type(batch_trace_input) is tuple else batch_trace_input.shape,
            )
            for i in range(len(topology.raw_model.input_names))
        ]
        if remainder_trace_input is not None:
            remainder_test_input = [
                (
                    topology.raw_model.input_names[i],
                    remainder_trace_input[i].shape if type(remainder_trace_input) is tuple else remainder_trace_input.shape,
                )
                for i in range(len(topology.raw_model.input_names))
            ]

        # Pick the proper target.
        if device == "cuda":
            target = tvm.target.cuda()
            ctx = tvm.gpu()
        elif device == "cpu":
            target = "llvm"
            ctx = tvm.cpu()
        elif "llvm" in device:
            target = device
            ctx = tvm.cpu()
        else:
            raise RuntimeError("Device {} not recognized".format(device))

        # Get configuration parameters.
        config = {}
        if constants.TVM_MAX_FUSE_DEPTH in extra_config:
            config["relay.FuseOps.max_depth"] = extra_config[constants.TVM_MAX_FUSE_DEPTH]
        else:
            # 50 is a good depth for operator fusion. More than that will probably hurt performance.
            # https://github.com/microsoft/hummingbird/issues/232#issuecomment-697979508
            config["relay.FuseOps.max_depth"] = 50

        # Create the relay version of the model.
        model, params = relay.frontend.from_pytorch(ts_model, test_input)
        if remainder_trace_input is not None:
            remainder_model, remainder_params = relay.frontend.from_pytorch(remainder_ts_model, remainder_test_input)

        # Generate the model. We set opt_level=3 to enable all optimizations.
        with tvm.transform.PassContext(opt_level=3, config=config):
            graph, lib, params = relay.build(model, target=target, params=params)
        tvm_model = graph_runtime.create(graph, lib, ctx)
        tvm_model.set_input(**params)
        if remainder_trace_input is not None:
            with tvm.transform.PassContext(opt_level=3, config=config):
                graph, lib, params = relay.build(remainder_model, target=target, params=remainder_params)
            tvm_remainder_model = graph_runtime.create(graph, lib, ctx)
            tvm_remainder_model.set_input(**params)

        # In the container we will be using the context to properly configure the input tensors.
        extra_config[constants.TVM_CONTEXT] = ctx
        extra_config[constants.TVM_INPUT_NAMES] = topology.raw_model.input_names
        if remainder_trace_input is not None:
            extra_config[constants.TVM_REMAINDER_MODEL] = tvm_remainder_model

        hb_model = tvm_model
    else:
        # Set the device for the model.
        if device != "cpu":
@@ -193,18 +297,19 @@ def convert(topology, backend, device, extra_config={}):
        # If the backend is torchscript, jit the model.
        if backend == torch.jit.__name__:
            trace_input, _ = _get_trace_input_from_test_input(extra_config[constants.TEST_INPUT], batch_size)
            if device != "cpu":
                trace_input.to(device)
            torch_model = torch.jit.trace(torch_model, trace_input).eval()
            torch.jit.optimized_execution(torch_model)
        hb_model = torch_model
    # Return if the container is not needed.
    if constants.CONTAINER in extra_config and not extra_config[constants.CONTAINER]:
        return hb_model

    # We scan the operators backwards until we find an operator with a defined type.
    # This is necessary because ONNX models can have arbitrary operators doing casting, reshaping etc.
    idx = len(operators) - 1
    while (
@@ -232,12 +337,15 @@ def convert(topology, backend, device, extra_config={}):
    if idx < 0:
        idx = tmp_idx

    # Get the proper container type.
    if operator_map[operators[idx].full_name].regression:
        # We are doing a regression task.
        if backend == torch.jit.__name__:
            container = TorchScriptSklearnContainerRegression
        elif backend == onnx.__name__:
            container = ONNXSklearnContainerRegression
        elif backend == tvm_backend:
            container = TVMSklearnContainerRegression
        else:
            container = PyTorchSklearnContainerRegression
    elif operator_map[operators[idx].full_name].anomaly_detection:
@@ -246,6 +354,8 @@ def convert(topology, backend, device, extra_config={}):
            container = TorchScriptSklearnContainerAnomalyDetection
        elif backend == onnx.__name__:
            container = ONNXSklearnContainerAnomalyDetection
        elif backend == tvm_backend:
            container = TVMSklearnContainerAnomalyDetection
        else:
            container = PyTorchSklearnContainerAnomalyDetection
    elif operator_map[operators[idx].full_name].transformer:
@@ -254,6 +364,8 @@ def convert(topology, backend, device, extra_config={}):
            container = TorchScriptSklearnContainerTransformer
        elif backend == onnx.__name__:
            container = ONNXSklearnContainerTransformer
        elif backend == tvm_backend:
            container = TVMSklearnContainerTransformer
        else:
            container = PyTorchSklearnContainerTransformer
    else:
@@ -262,6 +374,8 @@ def convert(topology, backend, device, extra_config={}):
            container = TorchScriptSklearnContainerClassification
        elif backend == onnx.__name__:
            container = ONNXSklearnContainerClassification
        elif backend == tvm_backend:
            container = TVMSklearnContainerClassification
        else:
            container = PyTorchSklearnContainerClassification
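For reference, the TVM compilation path used above can be exercised standalone roughly as follows (a sketch against the TVM 0.7-era `graph_runtime` API; the input name and shapes are illustrative):

```python
import torch
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

# Trace any PyTorch module with a fixed-shape example input.
model = torch.nn.Linear(28, 2).eval()
example = torch.randn(100, 28)
ts_model = torch.jit.trace(model, example).eval()

# Convert to Relay; input shapes are frozen at compile time.
mod, params = relay.frontend.from_pytorch(ts_model, [("input_0", list(example.shape))])

# Build with operator fusion capped, as the converter above does by default.
with tvm.transform.PassContext(opt_level=3, config={"relay.FuseOps.max_depth": 50}):
    graph, lib, params = relay.build(mod, target="llvm", params=params)

ctx = tvm.cpu()
rt = graph_runtime.create(graph, lib, ctx)
rt.set_input(**params)
rt.run(input_0=tvm.nd.array(example.numpy(), ctx))
print(rt.get_output(0).asnumpy()[:5])
```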


@@ -113,6 +113,18 @@ def xgboost_installed():
    return True


def tvm_installed():
    """
    Checks that *TVM* is available.
    """
    try:
        import tvm
        from tvm import relay
    except ImportError:
        return False
    return True


def pandas_installed():
    """
    Checks that *Pandas* is available.


@@ -22,6 +22,7 @@ from ._utils import (
    sparkml_installed,
    is_pandas_dataframe,
    is_spark_dataframe,
    tvm_installed,
)
from .exceptions import MissingConverter, MissingBackend
from .supported import backends
@@ -70,7 +71,13 @@ def _supported_backend_check_config(model, backend, extra_config):
    import onnx
    import torch

    tvm_backend = None
    if tvm_installed():
        import tvm

        tvm_backend = tvm.__name__

    if (backend == torch.jit.__name__ or backend == tvm_backend) and constants.TEST_INPUT not in extra_config:
        raise RuntimeError("Backend {} requires test inputs. Please pass some test input to the convert.".format(backend))
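A quick sketch of what this check guards against (hypothetical repro; assumes TVM is installed):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from hummingbird.ml import convert

X = np.random.rand(100, 4).astype(np.float32)
y = np.random.randint(2, size=100)
model = DecisionTreeClassifier().fit(X, y)

try:
    convert(model, "tvm")  # no test_input supplied
except RuntimeError as e:
    print(e)  # "Backend tvm requires test inputs. ..."
```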
@ -262,7 +269,7 @@ def convert(model, backend, test_input=None, device="cpu", extra_config={}):
For *LightGBM* and *XGBoost* currently only the Sklearn API is supported. For *LightGBM* and *XGBoost* currently only the Sklearn API is supported.
The detailed list of models and backends can be found at `hummingbird.ml.supported`. The detailed list of models and backends can be found at `hummingbird.ml.supported`.
The *onnx* backend requires either a test_input of a the initial types set through the exta_config parameter. The *onnx* backend requires either a test_input of a the initial types set through the exta_config parameter.
The *torch.jit* backend requires a test_input. The *torch.jit* and *tvm* backends requires a test_input.
[Sklearn]: https://scikit-learn.org/ [Sklearn]: https://scikit-learn.org/
[LightGBM]: https://lightgbm.readthedocs.io/ [LightGBM]: https://lightgbm.readthedocs.io/
[XGBoost]: https://xgboost.readthedocs.io/ [XGBoost]: https://xgboost.readthedocs.io/
@ -276,8 +283,8 @@ def convert(model, backend, test_input=None, device="cpu", extra_config={}):
backend: The target for the conversion backend: The target for the conversion
test_input: Some input data used to trace the model execution. test_input: Some input data used to trace the model execution.
Multiple inputs can be passed as `tuple` objects or pandas Dataframes. Multiple inputs can be passed as `tuple` objects or pandas Dataframes.
When possible, (`numpy`)`arrays` are suggesed. When possible, (`numpy`)`arrays` are suggested.
device: The target device the model should be run. This parameter is only used by the *torch** backends, and device: The target device the model should be run. This parameter is only used by the *torch** backends and *tvm*, and
the devices supported are the one supported by PyTorch, i.e., 'cpu' or 'cuda'. the devices supported are the one supported by PyTorch, i.e., 'cpu' or 'cuda'.
extra_config: Extra configurations to be used by the individual operator converters. extra_config: Extra configurations to be used by the individual operator converters.
The set of supported extra configurations can be found at `hummingbird.ml.supported` The set of supported extra configurations can be found at `hummingbird.ml.supported`


@@ -17,7 +17,7 @@ class LinearModel(BaseOperator, torch.nn.Module):
    def __init__(self, coefficients, intercepts, device, classes=[0], multi_class=None, is_linear_regression=False):
        super(LinearModel, self).__init__()
        self.coefficients = torch.nn.Parameter(torch.from_numpy(coefficients), requires_grad=False)
        self.intercepts = torch.nn.Parameter(torch.from_numpy(intercepts).view(-1), requires_grad=False)
        self.classes = torch.nn.Parameter(torch.IntTensor(classes), requires_grad=False)
        self.multi_class = multi_class
        self.regression = is_linear_regression


@ -19,7 +19,9 @@ class BernoulliNBModel(BaseOperator, torch.nn.Module):
super(BernoulliNBModel, self).__init__() super(BernoulliNBModel, self).__init__()
self.classification = True self.classification = True
self.binarize = binarize self.binarize = binarize
self.jll_calc_bias = torch.nn.Parameter(torch.from_numpy(jll_calc_bias.astype("float32")), requires_grad=False) self.jll_calc_bias = torch.nn.Parameter(
torch.from_numpy(jll_calc_bias.astype("float32")).view(-1), requires_grad=False
)
self.feature_log_prob_minus_neg_prob = torch.nn.Parameter( self.feature_log_prob_minus_neg_prob = torch.nn.Parameter(
torch.from_numpy(feature_log_prob_minus_neg_prob.astype("float32")), requires_grad=False torch.from_numpy(feature_log_prob_minus_neg_prob.astype("float32")), requires_grad=False
) )


@ -25,10 +25,10 @@ class Scaler(BaseOperator, torch.nn.Module):
self.scale = scale self.scale = scale
if offset is not None: if offset is not None:
self.offset = torch.nn.Parameter(torch.FloatTensor([offset]), requires_grad=False) self.offset = torch.nn.Parameter(torch.DoubleTensor([offset]), requires_grad=False)
if scale is not None: if scale is not None:
self.scale = torch.nn.Parameter(torch.FloatTensor([scale]), requires_grad=False) self.scale = torch.nn.Parameter(torch.DoubleTensor([scale]), requires_grad=False)
def forward(self, x): def forward(self, x):
if self.offset is not None: if self.offset is not None:


@ -408,6 +408,6 @@ def convert_decision_ensemble_tree_common(
for tree_param in tree_parameters for tree_param in tree_parameters
] ]
if tree_type == TreeImpl.tree_trav: if tree_type == TreeImpl.tree_trav:
return TreeTraversalDecisionTreeImpl(net_parameters, max_depth, n_features, classes) return TreeTraversalDecisionTreeImpl(net_parameters, max_depth, n_features, classes, extra_config)
else: # Remaining possible case: tree_type == TreeImpl.perf_tree_trav else: # Remaining possible case: tree_type == TreeImpl.perf_tree_trav
return PerfectTreeTraversalDecisionTreeImpl(net_parameters, max_depth, n_features, classes) return PerfectTreeTraversalDecisionTreeImpl(net_parameters, max_depth, n_features, classes)


@ -162,7 +162,7 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl):
if self.anomaly_detection: if self.anomaly_detection:
# Select the class (-1 if negative) and return the score. # Select the class (-1 if negative) and return the score.
return torch.where(x < 0, self.classes[0], self.classes[1]), x return torch.where(x.view(-1) < 0, self.classes[0], self.classes[1]), x
if self.perform_class_select: if self.perform_class_select:
return torch.index_select(self.classes, 0, torch.argmax(x, dim=1)), x return torch.index_select(self.classes, 0, torch.argmax(x, dim=1)), x
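As a standalone illustration of the class-selection pattern in the hunk above (a sketch, not the library code): for anomaly detection the score tensor is flattened so that torch.where yields a 1-D label vector, mapping negative scores to the first class (-1) and the rest to the second.

    import torch

    classes = torch.IntTensor([-1, 1])              # [anomaly, inlier]
    scores = torch.tensor([[-0.3], [0.7], [-0.1]])
    labels = torch.where(scores.view(-1) < 0, classes[0], classes[1])
    # labels -> tensor([-1, 1, -1], dtype=torch.int32)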
@ -175,7 +175,12 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
Class implementing the Tree Traversal strategy in PyTorch for tree-base models. Class implementing the Tree Traversal strategy in PyTorch for tree-base models.
""" """
def __init__(self, tree_parameters, max_depth, n_features, classes, n_classes=None, **kwargs): def _expand_indexes(self, batch_size):
indexes = self.nodes_offset
indexes = indexes.expand(batch_size, self.num_trees)
return indexes.reshape(-1)
def __init__(self, tree_parameters, max_depth, n_features, classes, n_classes=None, extra_config={}, **kwargs):
""" """
Args: Args:
tree_parameters: The parameters defining the tree structure tree_parameters: The parameters defining the tree structure
@ -183,6 +188,7 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
n_features: The number of features input to the model n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model classes: The classes used for classification. None if implementing a regression model
n_classes: The total number of used classes n_classes: The total number of used classes
extra_config: Extra configuration used to properly implement the source tree
""" """
super(TreeTraversalTreeImpl, self).__init__(tree_parameters, n_features, classes, n_classes, **kwargs) super(TreeTraversalTreeImpl, self).__init__(tree_parameters, n_features, classes, n_classes, **kwargs)
@ -192,8 +198,8 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
self.num_trees = len(tree_parameters) self.num_trees = len(tree_parameters)
self.num_nodes = max([len(tree_parameter[1]) for tree_parameter in tree_parameters]) self.num_nodes = max([len(tree_parameter[1]) for tree_parameter in tree_parameters])
lefts = np.zeros((self.num_trees, self.num_nodes), dtype=np.float32) lefts = np.zeros((self.num_trees, self.num_nodes), dtype=np.int64)
rights = np.zeros((self.num_trees, self.num_nodes), dtype=np.float32) rights = np.zeros((self.num_trees, self.num_nodes), dtype=np.int64)
features = np.zeros((self.num_trees, self.num_nodes), dtype=np.int64) features = np.zeros((self.num_trees, self.num_nodes), dtype=np.int64)
thresholds = np.zeros((self.num_trees, self.num_nodes), dtype=np.float32) thresholds = np.zeros((self.num_trees, self.num_nodes), dtype=np.float32)
@ -220,9 +226,7 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
return x return x
def forward(self, x): def forward(self, x):
indexes = self.nodes_offset indexes = self._expand_indexes(x.size()[0])
indexes = indexes.expand(x.size()[0], self.num_trees)
indexes = indexes.reshape(-1)
for _ in range(self.max_tree_depth): for _ in range(self.max_tree_depth):
tree_nodes = indexes tree_nodes = indexes
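To make the refactored _expand_indexes concrete, here is a small sketch (shapes are assumptions for illustration): each tree's root offset is replicated across the batch so that every (example, tree) pair starts traversal at its own root in the flattened node tensor.

    import torch

    num_trees, num_nodes, batch_size = 3, 5, 2
    # Tree i starts at node i * num_nodes in the flattened layout.
    nodes_offset = (torch.arange(num_trees) * num_nodes).long()   # [0, 5, 10]
    indexes = nodes_offset.expand(batch_size, num_trees).reshape(-1)
    # indexes -> tensor([0, 5, 10, 0, 5, 10]); one root per (example, tree).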
@ -246,7 +250,7 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
if self.anomaly_detection: if self.anomaly_detection:
# Select the class (-1 if negative) and return the score. # Select the class (-1 if negative) and return the score.
return torch.where(output < 0, self.classes[0], self.classes[1]), output return torch.where(output.view(-1) < 0, self.classes[0], self.classes[1]), output
if self.perform_class_select: if self.perform_class_select:
return torch.index_select(self.classes, 0, torch.argmax(output, dim=1)), output return torch.index_select(self.classes, 0, torch.argmax(output, dim=1)), output
@ -325,11 +329,9 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
for nodes, biases in zip(self.nodes, self.biases): for nodes, biases in zip(self.nodes, self.biases):
gather_indices = torch.index_select(nodes, 0, prev_indices).view(-1, self.num_trees) gather_indices = torch.index_select(nodes, 0, prev_indices).view(-1, self.num_trees)
features = torch.gather(x, 1, gather_indices).view(-1) features = torch.gather(x, 1, gather_indices).view(-1)
prev_indices = factor * prev_indices + torch.ge(features, torch.index_select(biases, 0, prev_indices)).long().view( prev_indices = factor * prev_indices + torch.ge(features, torch.index_select(biases, 0, prev_indices)).long()
-1
)
output = torch.index_select(self.leaf_nodes, 0, prev_indices.view(-1)).view(-1, self.num_trees, self.n_classes) output = torch.index_select(self.leaf_nodes, 0, prev_indices).view(-1, self.num_trees, self.n_classes)
output = self.aggregation(output) output = self.aggregation(output)
@ -338,7 +340,7 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
if self.anomaly_detection: if self.anomaly_detection:
# Select the class (-1 if negative) and return the score. # Select the class (-1 if negative) and return the score.
return torch.where(output < 0, self.classes[0], self.classes[1]), output return torch.where(output.view(-1) < 0, self.classes[0], self.classes[1]), output
if self.perform_class_select: if self.perform_class_select:
return torch.index_select(self.classes, 0, torch.argmax(output, dim=1)), output return torch.index_select(self.classes, 0, torch.argmax(output, dim=1)), output
@ -414,15 +416,18 @@ class TreeTraversalDecisionTreeImpl(TreeTraversalTreeImpl):
Class implementing the Tree Traversal strategy in PyTorch for decision tree models. Class implementing the Tree Traversal strategy in PyTorch for decision tree models.
""" """
def __init__(self, tree_parameters, max_depth, n_features, classes=None): def __init__(self, tree_parameters, max_depth, n_features, classes=None, extra_config={}):
""" """
Args: Args:
tree_parameters: The parameters defining the tree structure tree_parameters: The parameters defining the tree structure
max_depth: The maximum tree-depth in the model max_depth: The maximum tree-depth in the model
n_features: The number of features input to the model n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model classes: The classes used for classification. None if implementing a regression model
extra_config: Extra configuration used to properly implement the source tree
""" """
super(TreeTraversalDecisionTreeImpl, self).__init__(tree_parameters, max_depth, n_features, classes) super(TreeTraversalDecisionTreeImpl, self).__init__(
tree_parameters, max_depth, n_features, classes, extra_config=extra_config
)
def aggregation(self, x): def aggregation(self, x):
output = x.sum(1) output = x.sum(1)
@ -498,7 +503,7 @@ class TreeTraversalGBDTImpl(TreeTraversalTreeImpl):
classes: The classes used for classification. None if implementing a regression model classes: The classes used for classification. None if implementing a regression model
extra_config: Extra configuration used to properly implement the source tree extra_config: Extra configuration used to properly implement the source tree
""" """
super(TreeTraversalGBDTImpl, self).__init__(tree_parameters, max_detph, n_features, classes, 1) super(TreeTraversalGBDTImpl, self).__init__(tree_parameters, max_detph, n_features, classes, 1, extra_config)
self.n_gbdt_classes = 1 self.n_gbdt_classes = 1
self.post_transform = lambda x: x self.post_transform = lambda x: x


@ -38,6 +38,15 @@ ONNX_INITIALIZERS = "onnx_initializers"
ONNX_INPUTS = "onnx_inputs" ONNX_INPUTS = "onnx_inputs"
"""The input of the onnx model.""" """The input of the onnx model."""
TVM_CONTEXT = "tvm_context"
"""The context for TVM containing information on the target."""
TVM_INPUT_NAMES = "tvm_input_names"
"""TVM expects named inputs. This is used to set the names for the inputs."""
TVM_REMAINDER_MODEL = "tvm_remainder_model"
"""TVM is statically compiled and when batching we may need to use a different model for the remainder part of the records."""
TEST_INPUT = "test_input" TEST_INPUT = "test_input"
"""The test input data for models that need to be traced.""" """The test input data for models that need to be traced."""


@ -30,7 +30,7 @@ class SVC(BaseOperator, torch.nn.Module):
self.coef0 = coef0 self.coef0 = coef0
self.n_features = sv.shape[1] self.n_features = sv.shape[1]
self.a = a self.a = a
self.b = torch.nn.Parameter(torch.nn.Parameter(torch.from_numpy(b.reshape(1, -1)).double()), requires_grad=False) self.b = torch.nn.Parameter(torch.from_numpy(b.reshape(1, -1)).double(), requires_grad=False)
self.start = [sum(nv[:i]) for i in range(len(nv))] self.start = [sum(nv[:i]) for i in range(len(nv))]
self.end = [self.start[i] + nv[i] for i in range(len(nv))] self.end = [self.start[i] + nv[i] for i in range(len(nv))]
self.len_nv = len(nv) self.len_nv = len(nv)
@ -52,7 +52,7 @@ class SVC(BaseOperator, torch.nn.Module):
# using quadratic expansion--susceptible to rounding-off errors # using quadratic expansion--susceptible to rounding-off errors
# http://www.robots.ox.ac.uk/~albanie/notes/Euclidean_distance_trick.pdf # http://www.robots.ox.ac.uk/~albanie/notes/Euclidean_distance_trick.pdf
x_norm = -self.gamma * (x ** 2).sum(1).view(-1, 1) x_norm = -self.gamma * (x ** 2).sum(1).view(-1, 1)
k = torch.exp(x_norm + self.sv_norm + 2.0 * self.gamma * torch.mm(x, self.sv_t)) k = torch.exp(x_norm + self.sv_norm + 2.0 * self.gamma * torch.mm(x, self.sv_t).double())
elif self.kernel == "sigmoid": elif self.kernel == "sigmoid":
k = torch.sigmoid(self.gamma * torch.mm(x, self.sv_t) + self.coef0) k = torch.sigmoid(self.gamma * torch.mm(x, self.sv_t) + self.coef0)
else: # poly kernel else: # poly kernel
@ -69,10 +69,10 @@ class SVC(BaseOperator, torch.nn.Module):
class_ids = torch.gt(c, 0.0).int().flatten() class_ids = torch.gt(c, 0.0).int().flatten()
else: else:
votes = torch.where(c > 0, self.true_classes, self.false_classes) votes = torch.where(c > 0, self.true_classes, self.false_classes)
# TODO mode is still not implemented for GPU backend # TODO mode is still not implemented for GPU backend.
votes = votes.data.cpu() votes = votes.data.cpu()
class_ids, _ = torch.mode(votes, dim=1) class_ids, _ = torch.mode(votes, dim=1)
# no class probabilities in SVC # No class probabilities in SVC.
if self.perform_class_select: if self.perform_class_select:
temp = torch.index_select(self.classes, 0, class_ids.long()) temp = torch.index_select(self.classes, 0, class_ids.long())
return temp, temp return temp, temp
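The RBF branch above relies on the quadratic expansion ||x - v||^2 = ||x||^2 - 2 x.v + ||v||^2 (see the linked note), which avoids materializing pairwise differences; a self-contained sketch with assumed shapes:

    import torch

    gamma = 0.5
    x = torch.randn(4, 3).double()     # batch of inputs
    sv = torch.randn(6, 3).double()    # support vectors
    sv_t = sv.t()
    sv_norm = -gamma * (sv ** 2).sum(1).view(1, -1)
    x_norm = -gamma * (x ** 2).sum(1).view(-1, 1)
    # exp(-gamma * ||x - sv||^2) without forming the (4, 6, 3) difference tensor.
    k = torch.exp(x_norm + sv_norm + 2.0 * gamma * torch.mm(x, sv_t))
    assert torch.allclose(k, torch.exp(-gamma * torch.cdist(x, sv) ** 2))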


@ -10,7 +10,8 @@ All operators, backends, and configurations settings supported in Hummingbird ar
**Supported Backends** **Supported Backends**
PyTorch, PyTorch,
TorchScript, TorchScript,
ONNX ONNX,
TVM
**Supported Operators (scikit-learn)** **Supported Operators (scikit-learn)**
BernoulliNB, BernoulliNB,
@ -90,6 +91,7 @@ from ._utils import (
lightgbm_installed, lightgbm_installed,
xgboost_installed, xgboost_installed,
onnx_runtime_installed, onnx_runtime_installed,
tvm_installed,
sparkml_installed, sparkml_installed,
) )
@ -319,6 +321,11 @@ def _build_backend_map():
backends[onnx.__name__] = onnx.__name__ backends[onnx.__name__] = onnx.__name__
if tvm_installed():
import tvm
backends[tvm.__name__] = tvm.__name__
return backends return backends
@ -429,6 +436,11 @@ ONNX_OUTPUT_MODEL_NAME = "onnx_model_name"
ONNX_TARGET_OPSET = "onnx_target_opset" ONNX_TARGET_OPSET = "onnx_target_opset"
"""For ONNX models we can set the target opset to use. 11 by default.""" """For ONNX models we can set the target opset to use. 11 by default."""
TVM_MAX_FUSE_DEPTH = "tvm_max_fuse_depth"
"""For TVM we can fix the number of operations that will be fused.
If not set, compilation may take forever (https://github.com/microsoft/hummingbird/issues/232).
By default Hummingbird uses a max_fuse_depth of 50, but this can be override using this parameter."""
INPUT_NAMES = "input_names" INPUT_NAMES = "input_names"
"""Set the names of the inputs. Assume that the numbers of inputs_names is equal to the number of inputs.""" """Set the names of the inputs. Assume that the numbers of inputs_names is equal to the number of inputs."""


@ -17,7 +17,7 @@ from onnxconverter_common.data_types import (
) )
import hummingbird.ml import hummingbird.ml
from hummingbird.ml._utils import onnx_ml_tools_installed, onnx_runtime_installed from hummingbird.ml._utils import onnx_ml_tools_installed, onnx_runtime_installed, tvm_installed
from hummingbird.ml.exceptions import MissingBackend from hummingbird.ml.exceptions import MissingBackend
if onnx_ml_tools_installed(): if onnx_ml_tools_installed():
@ -96,7 +96,24 @@ class TestBackends(unittest.TestCase):
# Test torchscript requires test_input # Test torchscript requires test_input
self.assertRaises(RuntimeError, hummingbird.ml.convert, model, "torch.jit") self.assertRaises(RuntimeError, hummingbird.ml.convert, model, "torch.jit")
# Test onnx no test_data, float input # Test TVM requires test_data
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
def test_tvm_test_data(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)
model.fit(X, y)
# Test tvm requires test_input
self.assertRaises(RuntimeError, hummingbird.ml.convert, model, "tvm")
# Test onnx requires test_data or initial_types
@unittest.skipIf( @unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML tests require ONNX, ORT and ONNXMLTOOLS" not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML tests require ONNX, ORT and ONNXMLTOOLS"
) )


@ -17,7 +17,13 @@ from sklearn.pipeline import Pipeline
import torch import torch
import hummingbird.ml import hummingbird.ml
from hummingbird.ml._utils import onnx_ml_tools_installed, onnx_runtime_installed, pandas_installed, lightgbm_installed from hummingbird.ml._utils import (
onnx_ml_tools_installed,
onnx_runtime_installed,
pandas_installed,
lightgbm_installed,
tvm_installed,
)
from hummingbird.ml import constants from hummingbird.ml import constants
if lightgbm_installed(): if lightgbm_installed():
@ -384,6 +390,156 @@ class TestExtraConf(unittest.TestCase):
np.testing.assert_allclose(model.decision_function(X), hb_model.decision_function(X), rtol=1e-06, atol=1e-06) np.testing.assert_allclose(model.decision_function(X), hb_model.decision_function(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.score_samples(X), hb_model.score_samples(X), rtol=1e-06, atol=1e-06) np.testing.assert_allclose(model.score_samples(X), hb_model.score_samples(X), rtol=1e-06, atol=1e-06)
# Test tvm transform with batching.
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM")
def test_tvm_batch_transform(self):
warnings.filterwarnings("ignore")
model = StandardScaler(with_mean=True, with_std=True)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
model.fit(X)
hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.transform(X), hb_model.transform(X), rtol=1e-06, atol=1e-06)
# Test tvm regression with batching.
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM")
def test_tvm_regression_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingRegressor(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
# Test tvm classification with batching.
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM")
def test_tvm_classification_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)
# Test tvm iforest with batching.
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM")
def test_tvm_iforest_batch(self):
warnings.filterwarnings("ignore")
num_classes = 2
model = IsolationForest(n_estimators=10, max_samples=2)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.decision_function(X), hb_model.decision_function(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.score_samples(X), hb_model.score_samples(X), rtol=1e-06, atol=1e-06)
# Test tvm transform with batching and uneven number of records.
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM")
def test_tvm_batch_remainder_transform(self):
warnings.filterwarnings("ignore")
model = StandardScaler(with_mean=True, with_std=True)
np.random.seed(0)
X = np.random.rand(105, 200)
X = np.array(X, dtype=np.float32)
model.fit(X)
hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.transform(X), hb_model.transform(X), rtol=1e-06, atol=1e-06)
# Test tvm regression with batching and uneven number of records.
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM")
def test_tvm_regression_remainder_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingRegressor(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(105, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=105)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
# Test tvm classification with batching and uneven number of records.
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM")
def test_tvm_classification_remainder_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(105, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=105)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)
# Test tvm iforest with batching and uneven number of records.
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM")
def test_tvm_iforest_remainder_batch(self):
warnings.filterwarnings("ignore")
num_classes = 2
model = IsolationForest(n_estimators=10, max_samples=2)
np.random.seed(0)
X = np.random.rand(105, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=105)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.decision_function(X), hb_model.decision_function(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.score_samples(X), hb_model.score_samples(X), rtol=1e-06, atol=1e-06)
# Test batch with pandas. # Test batch with pandas.
@unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed") @unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed")
def test_pandas_batch(self): def test_pandas_batch(self):
@ -391,8 +547,8 @@ class TestExtraConf(unittest.TestCase):
max_depth = 10 max_depth = 10
iris = datasets.load_iris() iris = datasets.load_iris()
X = iris.data[:, :3] X = iris.data[:149, :3]
y = iris.target y = iris.target[:149]
columns = ["vA", "vB", "vC"] columns = ["vA", "vB", "vC"]
X_train = pandas.DataFrame(X, columns=columns) X_train = pandas.DataFrame(X, columns=columns)
@ -413,15 +569,15 @@ class TestExtraConf(unittest.TestCase):
pipeline.predict_proba(X_train), torch_model.predict_proba(X_train), rtol=1e-06, atol=1e-06, pipeline.predict_proba(X_train), torch_model.predict_proba(X_train), rtol=1e-06, atol=1e-06,
) )
# Test batch with pandas. # Test batch with pandas and TorchScript.
@unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed") @unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed")
def test_pandas_batch_ts(self): def test_pandas_batch_ts(self):
import pandas import pandas
max_depth = 10 max_depth = 10
iris = datasets.load_iris() iris = datasets.load_iris()
X = iris.data[:, :3] X = iris.data[:149, :3]
y = iris.target y = iris.target[:149]
columns = ["vA", "vB", "vC"] columns = ["vA", "vB", "vC"]
X_train = pandas.DataFrame(X, columns=columns) X_train = pandas.DataFrame(X, columns=columns)
@ -442,7 +598,7 @@ class TestExtraConf(unittest.TestCase):
pipeline.predict_proba(X_train), torch_model.predict_proba(X_train), rtol=1e-06, atol=1e-06, pipeline.predict_proba(X_train), torch_model.predict_proba(X_train), rtol=1e-06, atol=1e-06,
) )
# Test batch with pandas. # Test batch with pandas and ONNX.
@unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed") @unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed")
@unittest.skipIf(not onnx_runtime_installed(), reason="ONNXML test require ONNX and ORT") @unittest.skipIf(not onnx_runtime_installed(), reason="ONNXML test require ONNX and ORT")
def test_pandas_batch_onnx(self): def test_pandas_batch_onnx(self):
@ -450,8 +606,8 @@ class TestExtraConf(unittest.TestCase):
max_depth = 10 max_depth = 10
iris = datasets.load_iris() iris = datasets.load_iris()
X = iris.data[:, :3] X = iris.data[:149, :3]
y = iris.target y = iris.target[:149]
columns = ["vA", "vB", "vC"] columns = ["vA", "vB", "vC"]
X_train = pandas.DataFrame(X, columns=columns) X_train = pandas.DataFrame(X, columns=columns)
@ -472,7 +628,7 @@ class TestExtraConf(unittest.TestCase):
pipeline.predict_proba(X_train), hb_model.predict_proba(X_train), rtol=1e-06, atol=1e-06, pipeline.predict_proba(X_train), hb_model.predict_proba(X_train), rtol=1e-06, atol=1e-06,
) )
# Test batch with pandas. # Test batch with pandas and ONNX-ML.
@unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed") @unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed")
@unittest.skipIf( @unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML tests require ONNX, ORT and ONNXMLTOOLS" not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML tests require ONNX, ORT and ONNXMLTOOLS"
@ -515,6 +671,36 @@ class TestExtraConf(unittest.TestCase):
pipeline.predict_proba(X_train), hb_model.predict_proba(X_train), rtol=1e-06, atol=1e-06, pipeline.predict_proba(X_train), hb_model.predict_proba(X_train), rtol=1e-06, atol=1e-06,
) )
# Test batch with pandas and TVM.
@unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed")
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM")
def test_pandas_batch_tvm(self):
import pandas
max_depth = 10
iris = datasets.load_iris()
X = iris.data[:149, :3]
y = iris.target[:149]
columns = ["vA", "vB", "vC"]
X_train = pandas.DataFrame(X, columns=columns)
pipeline = Pipeline(
steps=[
("preprocessor", ColumnTransformer(transformers=[], remainder="passthrough",)),
("classifier", GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)),
]
)
pipeline.fit(X_train, y)
hb_model = hummingbird.ml.convert(pipeline, "tvm", X_train, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(
pipeline.predict_proba(X_train), hb_model.predict_proba(X_train), rtol=1e-06, atol=1e-06,
)
# Check converter with model name set as extra_config. # Check converter with model name set as extra_config.
@unittest.skipIf( @unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML tests require ONNX, ORT and ONNXMLTOOLS" not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML tests require ONNX, ORT and ONNXMLTOOLS"
@ -539,6 +725,21 @@ class TestExtraConf(unittest.TestCase):
assert onnx_model.model.graph.name == model_name assert onnx_model.model.graph.name == model_name
# Test max fuse depth configuration in TVM
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
def test_tvm_max_fuse_depth(self):
warnings.filterwarnings("ignore")
X = [[0, 1], [1, 1], [2, 0]]
X = np.array(X, dtype=np.float32)
y = np.array([100, -10, 50], dtype=np.float32)
model = lgb.LGBMRegressor(n_estimators=3, min_child_samples=1)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()


@ -7,7 +7,7 @@ import warnings
import numpy as np import numpy as np
import hummingbird.ml import hummingbird.ml
from hummingbird.ml._utils import lightgbm_installed, onnx_runtime_installed from hummingbird.ml._utils import lightgbm_installed, onnx_runtime_installed, tvm_installed
from tree_utils import gbdt_implementation_map from tree_utils import gbdt_implementation_map
if lightgbm_installed(): if lightgbm_installed():
@ -301,8 +301,6 @@ class TestLGBMConverter(unittest.TestCase):
@unittest.skipIf(not onnx_runtime_installed(), reason="ONNXML tests require ONNX, ORT and ONNXMLTOOLS") @unittest.skipIf(not onnx_runtime_installed(), reason="ONNXML tests require ONNX, ORT and ONNXMLTOOLS")
@unittest.skipIf(not lightgbm_installed(), reason="LightGBM test requires LightGBM installed") @unittest.skipIf(not lightgbm_installed(), reason="LightGBM test requires LightGBM installed")
def test_lightgbm_onnx(self): def test_lightgbm_onnx(self):
import onnxruntime as ort
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
X = [[0, 1], [1, 1], [2, 0]] X = [[0, 1], [1, 1], [2, 0]]
@ -314,10 +312,63 @@ class TestLGBMConverter(unittest.TestCase):
# Create ONNX model # Create ONNX model
onnx_model = hummingbird.ml.convert(model, "onnx", X) onnx_model = hummingbird.ml.convert(model, "onnx", X)
# Get the predictions for the ONNX-ML model np.testing.assert_allclose(onnx_model.predict(X)[0].flatten(), model.predict(X))
onnx_pred = onnx_model.predict(X)
np.testing.assert_allclose(onnx_pred[0].flatten(), model.predict(X)) # TVM backend tests.
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_lightgbm_tvm_regressor(self):
warnings.filterwarnings("ignore")
for tree_implementation in ["gemm", "tree_trav", "perf_tree_trav"]:
X = [[0, 1], [1, 1], [2, 0]]
X = np.array(X, dtype=np.float32)
y = np.array([100, -10, 50], dtype=np.float32)
model = lgb.LGBMRegressor(n_estimators=3, min_child_samples=1)
model.fit(X, y)
# Create TVM model.
tvm_model = hummingbird.ml.convert(model, "tvm", X, extra_config={"tree_implementation": tree_implementation})
# Check results.
np.testing.assert_allclose(tvm_model.predict(X), model.predict(X))
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM installed")
def test_lightgbm_tvm_classifier(self):
warnings.filterwarnings("ignore")
for tree_implementation in ["gemm", "tree_trav", "perf_tree_trav"]:
X = [[0, 1], [1, 1], [2, 0]]
X = np.array(X, dtype=np.float32)
y = np.array([0, 1, 0], dtype=np.float32)
model = lgb.LGBMClassifier(n_estimators=3, min_child_samples=1)
model.fit(X, y)
# Create TVM model.
tvm_model = hummingbird.ml.convert(model, "tvm", X, extra_config={"tree_implementation": tree_implementation})
# Check results.
np.testing.assert_allclose(tvm_model.predict(X), model.predict(X))
np.testing.assert_allclose(tvm_model.predict_proba(X), model.predict_proba(X))
# Test TVM with large input datasets.
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM installed")
def test_lightgbm_tvm_classifier_large_dataset(self):
warnings.filterwarnings("ignore")
for tree_implementation in ["gemm", "tree_trav", "perf_tree_trav"]:
size = 200000
X = np.random.rand(size, 28)
X = np.array(X, dtype=np.float32)
y = np.random.randint(2, size=size)
model = lgb.LGBMClassifier(n_estimators=100, max_depth=3)
model.fit(X, y)
# Create TVM model.
tvm_model = hummingbird.ml.convert(model, "tvm", X, extra_config={"tree_implementation": tree_implementation})
# Check results.
np.testing.assert_allclose(tvm_model.predict(X), model.predict(X))
np.testing.assert_allclose(tvm_model.predict_proba(X), model.predict_proba(X), rtol=1e-06, atol=1e-06)
if __name__ == "__main__": if __name__ == "__main__":


@ -6,7 +6,13 @@ import warnings
import numpy as np import numpy as np
from hummingbird.ml._utils import lightgbm_installed, xgboost_installed, onnx_runtime_installed, onnx_ml_tools_installed from hummingbird.ml._utils import (
lightgbm_installed,
xgboost_installed,
onnx_runtime_installed,
onnx_ml_tools_installed,
tvm_installed,
)
class TestNoExtra(unittest.TestCase): class TestNoExtra(unittest.TestCase):
@ -39,6 +45,12 @@ class TestNoExtra(unittest.TestCase):
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
assert not onnx_ml_tools_installed() assert not onnx_ml_tools_installed()
# Test no TVM returns false on tvm_installed()
@unittest.skipIf(tvm_installed(), reason="Test when TVM is not installed")
def test_tvm_installed_false(self):
warnings.filterwarnings("ignore")
assert not tvm_installed()
# Test that we can import the converter successfully without installing [extra] # Test that we can import the converter successfully without installing [extra]
def test_import_convert_no_extra(self): def test_import_convert_no_extra(self):
try: try:


@ -51,6 +51,33 @@ class TestONNXOneHotEncoder(unittest.TestCase):
# Check that predicted values match # Check that predicted values match
np.testing.assert_allclose(onnx_ml_pred, onnx_pred, rtol=rtol, atol=atol) np.testing.assert_allclose(onnx_ml_pred, onnx_pred, rtol=rtol, atol=atol)
# Test OneHotEncoder with 2 inputs
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
)
def test_one_hot_encoder_onnx2(self, rtol=1e-06, atol=1e-06):
model = OneHotEncoder()
X = np.array([[1, 2, 3], [2, 1, 3]], dtype=np.int32)
model.fit(X)
# Create ONNX-ML model
onnx_ml_model = convert_sklearn(model, initial_types=[("int_input", IntTensorType_onnx(X.shape))])
# Create ONNX model by calling converter
onnx_model = convert(onnx_ml_model, "onnx", X)
# Get the predictions for the ONNX-ML model
session = ort.InferenceSession(onnx_ml_model.SerializeToString())
output_names = [session.get_outputs()[i].name for i in range(len(session.get_outputs()))]
inputs = {session.get_inputs()[0].name: X}
onnx_ml_pred = session.run(output_names, inputs)
# Get the predictions for the ONNX model
onnx_pred = onnx_model.transform(X)
# Check that predicted values match
np.testing.assert_allclose(onnx_ml_pred, onnx_pred, rtol=rtol, atol=atol)
# Test OneHotEncoder with int64 # Test OneHotEncoder with int64
@unittest.skipIf( @unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS" not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"


@ -8,7 +8,7 @@ import numpy as np
from sklearn.preprocessing import MaxAbsScaler, MinMaxScaler, StandardScaler, RobustScaler from sklearn.preprocessing import MaxAbsScaler, MinMaxScaler, StandardScaler, RobustScaler
import torch import torch
from hummingbird.ml._utils import onnx_ml_tools_installed, onnx_runtime_installed, lightgbm_installed from hummingbird.ml._utils import onnx_ml_tools_installed, onnx_runtime_installed
from hummingbird.ml import convert from hummingbird.ml import convert
if onnx_runtime_installed(): if onnx_runtime_installed():


@ -10,6 +10,8 @@ from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
import hummingbird.ml import hummingbird.ml
from hummingbird.ml.exceptions import MissingConverter from hummingbird.ml.exceptions import MissingConverter
from hummingbird.ml._utils import tvm_installed
from hummingbird.ml import constants
from tree_utils import dt_implementation_map from tree_utils import dt_implementation_map
@ -26,7 +28,9 @@ class TestSklearnTreeConverter(unittest.TestCase):
for extra_config_param in ["tree_trav", "perf_tree_trav", "gemm"]: for extra_config_param in ["tree_trav", "perf_tree_trav", "gemm"]:
model.fit(X, y) model.fit(X, y)
torch_model = hummingbird.ml.convert(model, "torch", extra_config={"tree_implementation": extra_config_param}) torch_model = hummingbird.ml.convert(
model, "torch", extra_config={constants.TREE_IMPLEMENTATION: extra_config_param}
)
self.assertIsNotNone(torch_model) self.assertIsNotNone(torch_model)
self.assertTrue( self.assertTrue(
str(type(list(torch_model.model._operator_map.values())[0])) == dt_implementation_map[extra_config_param] str(type(list(torch_model.model._operator_map.values())[0])) == dt_implementation_map[extra_config_param]
@ -56,19 +60,19 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Random forest gemm classifier # Random forest gemm classifier
def test_random_forest_gemm_classifier_converter(self): def test_random_forest_gemm_classifier_converter(self):
self._run_tree_classification_converter( self._run_tree_classification_converter(
RandomForestClassifier, 2, extra_config={"tree_implementation": "gemm"}, n_estimators=10 RandomForestClassifier, 2, extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
) )
# Random forest tree_trav classifier # Random forest tree_trav classifier
def test_random_forest_tree_trav_classifier_converter(self): def test_random_forest_tree_trav_classifier_converter(self):
self._run_tree_classification_converter( self._run_tree_classification_converter(
RandomForestClassifier, 2, extra_config={"tree_implementation": "tree_trav"}, n_estimators=10 RandomForestClassifier, 2, extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}, n_estimators=10
) )
# Random forest perf_tree_trav classifier # Random forest perf_tree_trav classifier
def test_random_forest_perf_tree_trav_classifier_converter(self): def test_random_forest_perf_tree_trav_classifier_converter(self):
self._run_tree_classification_converter( self._run_tree_classification_converter(
RandomForestClassifier, 2, extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10 RandomForestClassifier, 2, extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"}, n_estimators=10
) )
# Random forest multi classifier # Random forest multi classifier
@ -78,37 +82,45 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Random forest gemm multi classifier # Random forest gemm multi classifier
def test_random_forest_gemm_multi_classifier_converter(self): def test_random_forest_gemm_multi_classifier_converter(self):
self._run_tree_classification_converter( self._run_tree_classification_converter(
RandomForestClassifier, 3, extra_config={"tree_implementation": "gemm"}, n_estimators=10 RandomForestClassifier, 3, extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
) )
# Random forest tree_trav multi classifier # Random forest tree_trav multi classifier
def test_random_forest_tree_trav_multi_classifier_converter(self): def test_random_forest_tree_trav_multi_classifier_converter(self):
self._run_tree_classification_converter( self._run_tree_classification_converter(
RandomForestClassifier, 3, extra_config={"tree_implementation": "tree_trav"}, n_estimators=10 RandomForestClassifier, 3, extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}, n_estimators=10
) )
# Random forest perf_tree_trav multi classifier # Random forest perf_tree_trav multi classifier
def test_random_forest_perf_tree_trav_multi_classifier_converter(self): def test_random_forest_perf_tree_trav_multi_classifier_converter(self):
self._run_tree_classification_converter( self._run_tree_classification_converter(
RandomForestClassifier, 3, extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10 RandomForestClassifier, 3, extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"}, n_estimators=10
) )
# Random forest gemm classifier shifted classes # Random forest gemm classifier shifted classes
def test_random_forest_gemm_classifier_shifted_labels_converter(self): def test_random_forest_gemm_classifier_shifted_labels_converter(self):
self._run_tree_classification_converter( self._run_tree_classification_converter(
RandomForestClassifier, 3, labels_shift=2, extra_config={"tree_implementation": "gemm"}, n_estimators=10 RandomForestClassifier, 3, labels_shift=2, extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
) )
# Random forest tree_trav classifier shifted classes # Random forest tree_trav classifier shifted classes
def test_random_forest_tree_trav_classifier_shifted_labels_converter(self): def test_random_forest_tree_trav_classifier_shifted_labels_converter(self):
self._run_tree_classification_converter( self._run_tree_classification_converter(
RandomForestClassifier, 3, labels_shift=2, extra_config={"tree_implementation": "tree_trav"}, n_estimators=10 RandomForestClassifier,
3,
labels_shift=2,
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"},
n_estimators=10,
) )
# Random forest perf_tree_trav classifier shifted classes # Random forest perf_tree_trav classifier shifted classes
def test_random_forest_perf_tree_trav_classifier_shifted_labels_converter(self): def test_random_forest_perf_tree_trav_classifier_shifted_labels_converter(self):
self._run_tree_classification_converter( self._run_tree_classification_converter(
RandomForestClassifier, 3, labels_shift=2, extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10 RandomForestClassifier,
3,
labels_shift=2,
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"},
n_estimators=10,
) )
# Used for regression tests # Used for regression tests
@ -133,19 +145,19 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Random forest gemm regressor # Random forest gemm regressor
def test_random_forest_gemm_regressor_converter(self): def test_random_forest_gemm_regressor_converter(self):
self._run_tree_regressor_converter( self._run_tree_regressor_converter(
RandomForestRegressor, 1000, extra_config={"tree_implementation": "gemm"}, n_estimators=10 RandomForestRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
) )
# Random forest tree_trav regressor # Random forest tree_trav regressor
def test_random_forest_tree_trav_regressor_converter(self): def test_random_forest_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter( self._run_tree_regressor_converter(
RandomForestRegressor, 1000, extra_config={"tree_implementation": "tree_trav"}, n_estimators=10 RandomForestRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}, n_estimators=10
) )
# Random forest perf_tree_trav regressor # Random forest perf_tree_trav regressor
def test_random_forest_perf_tree_trav_regressor_converter(self): def test_random_forest_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter( self._run_tree_regressor_converter(
RandomForestRegressor, 1000, extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10 RandomForestRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"}, n_estimators=10
) )
# Extra trees regressor # Extra trees regressor
@ -155,19 +167,19 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Extra trees gemm regressor # Extra trees gemm regressor
def test_extra_trees_gemm_regressor_converter(self): def test_extra_trees_gemm_regressor_converter(self):
self._run_tree_regressor_converter( self._run_tree_regressor_converter(
ExtraTreesRegressor, 1000, extra_config={"tree_implementation": "gemm"}, n_estimators=10 ExtraTreesRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
) )
# Extra trees tree_trav regressor # Extra trees tree_trav regressor
def test_extra_trees_tree_trav_regressor_converter(self): def test_extra_trees_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter( self._run_tree_regressor_converter(
ExtraTreesRegressor, 1000, extra_config={"tree_implementation": "tree_trav"}, n_estimators=10 ExtraTreesRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}, n_estimators=10
) )
# Extra trees perf_tree_trav regressor # Extra trees perf_tree_trav regressor
def test_extra_trees_perf_tree_trav_regressor_converter(self): def test_extra_trees_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter( self._run_tree_regressor_converter(
ExtraTreesRegressor, 1000, extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10 ExtraTreesRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"}, n_estimators=10
) )
# Decision tree regressor # Decision tree regressor
@ -176,15 +188,19 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Decision tree gemm regressor # Decision tree gemm regressor
def test_decision_tree_gemm_regressor_converter(self): def test_decision_tree_gemm_regressor_converter(self):
self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, extra_config={"tree_implementation": "gemm"}) self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "gemm"})
# Decision tree tree_trav regressor # Decision tree tree_trav regressor
def test_decision_tree_tree_trav_regressor_converter(self): def test_decision_tree_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, extra_config={"tree_implementation": "tree_trav"}) self._run_tree_regressor_converter(
DecisionTreeRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}
)
# Decision tree perf_tree_trav regressor # Decision tree perf_tree_trav regressor
def test_decision_tree_perf_tree_trav_regressor_converter(self): def test_decision_tree_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, extra_config={"tree_implementation": "perf_tree_trav"}) self._run_tree_regressor_converter(
DecisionTreeRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
)
# Decision tree classifier # Decision tree classifier
def test_decision_tree_classifier_converter(self): def test_decision_tree_classifier_converter(self):
@ -208,15 +224,19 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Small tree gemm implementation # Small tree gemm implementation
def test_random_forest_gemm_classifier_single_node_tree_converter(self): def test_random_forest_gemm_classifier_single_node_tree_converter(self):
self._run_random_forest_classifier_single_node_tree_converter(extra_config={"tree_implementation": "gemm"}) self._run_random_forest_classifier_single_node_tree_converter(extra_config={constants.TREE_IMPLEMENTATION: "gemm"})
# Small tree tree_trav implementation # Small tree tree_trav implementation
def test_random_forest_tree_trav_classifier_single_node_tree_converter(self): def test_random_forest_tree_trav_classifier_single_node_tree_converter(self):
self._run_random_forest_classifier_single_node_tree_converter(extra_config={"tree_implementation": "tree_trav"}) self._run_random_forest_classifier_single_node_tree_converter(
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}
)
# Small tree perf_tree_trav implementation # Small tree perf_tree_trav implementation
def test_random_forest_perf_tree_trav_classifier_single_node_tree_converter(self): def test_random_forest_perf_tree_trav_classifier_single_node_tree_converter(self):
self._run_random_forest_classifier_single_node_tree_converter(extra_config={"tree_implementation": "perf_tree_trav"}) self._run_random_forest_classifier_single_node_tree_converter(
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
)
# Float 64 classification test helper # Float 64 classification test helper
def _run_float64_tree_classification_converter(self, model_type, num_classes, extra_config={}, labels_shift=0, **kwargs): def _run_float64_tree_classification_converter(self, model_type, num_classes, extra_config={}, labels_shift=0, **kwargs):
@ -287,7 +307,7 @@ class TestSklearnTreeConverter(unittest.TestCase):
y = np.random.randint(3, size=100) y = np.random.randint(3, size=100)
model = RandomForestClassifier(n_estimators=10).fit(X, y) model = RandomForestClassifier(n_estimators=10).fit(X, y)
self.assertRaises( self.assertRaises(
MissingConverter, hummingbird.ml.convert, model, "torch", extra_config={"tree_implementation": "nonsense"} MissingConverter, hummingbird.ml.convert, model, "torch", extra_config={constants.TREE_IMPLEMENTATION: "nonsense"}
) )
# Test trees with TorchScript backend # Test trees with TorchScript backend
@ -298,19 +318,23 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Random forest gemm classifier # Random forest gemm classifier
def test_random_forest_ts_gemm_classifier_converter(self): def test_random_forest_ts_gemm_classifier_converter(self):
self._run_tree_classification_converter( self._run_tree_classification_converter(
RandomForestClassifier, 2, "torch.jit", extra_config={"tree_implementation": "gemm"}, n_estimators=10 RandomForestClassifier, 2, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
) )
# Random forest tree_trav classifier # Random forest tree_trav classifier
def test_random_forest_ts_tree_trav_classifier_converter(self): def test_random_forest_ts_tree_trav_classifier_converter(self):
self._run_tree_classification_converter( self._run_tree_classification_converter(
RandomForestClassifier, 2, "torch.jit", extra_config={"tree_implementation": "tree_trav"}, n_estimators=10 RandomForestClassifier, 2, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}, n_estimators=10
) )
# Random forest perf_tree_trav classifier # Random forest perf_tree_trav classifier
def test_random_forest_ts_perf_tree_trav_classifier_converter(self): def test_random_forest_ts_perf_tree_trav_classifier_converter(self):
self._run_tree_classification_converter( self._run_tree_classification_converter(
RandomForestClassifier, 2, "torch.jit", extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10 RandomForestClassifier,
2,
"torch.jit",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"},
n_estimators=10,
) )
# Random forest multi classifier # Random forest multi classifier
@ -320,19 +344,23 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Random forest gemm multi classifier # Random forest gemm multi classifier
def test_random_forest_ts_gemm_multi_classifier_converter(self): def test_random_forest_ts_gemm_multi_classifier_converter(self):
self._run_tree_classification_converter( self._run_tree_classification_converter(
RandomForestClassifier, 3, "torch.jit", extra_config={"tree_implementation": "gemm"}, n_estimators=10 RandomForestClassifier, 3, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
) )
# Random forest tree_trav multi classifier # Random forest tree_trav multi classifier
def test_random_forest_ts_tree_trav_multi_classifier_converter(self): def test_random_forest_ts_tree_trav_multi_classifier_converter(self):
self._run_tree_classification_converter( self._run_tree_classification_converter(
RandomForestClassifier, 3, "torch.jit", extra_config={"tree_implementation": "tree_trav"}, n_estimators=10 RandomForestClassifier, 3, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}, n_estimators=10
) )
# Random forest perf_tree_trav multi classifier # Random forest perf_tree_trav multi classifier
def test_random_forest_ts_perf_tree_trav_multi_classifier_converter(self): def test_random_forest_ts_perf_tree_trav_multi_classifier_converter(self):
self._run_tree_classification_converter( self._run_tree_classification_converter(
RandomForestClassifier, 3, "torch.jit", extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10 RandomForestClassifier,
3,
"torch.jit",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"},
n_estimators=10,
) )
# Random forest gemm classifier shifted classes # Random forest gemm classifier shifted classes
@ -342,7 +370,7 @@ class TestSklearnTreeConverter(unittest.TestCase):
3, 3,
"torch.jit", "torch.jit",
labels_shift=2, labels_shift=2,
extra_config={"tree_implementation": "gemm"}, extra_config={constants.TREE_IMPLEMENTATION: "gemm"},
n_estimators=10, n_estimators=10,
) )
@ -353,7 +381,7 @@ class TestSklearnTreeConverter(unittest.TestCase):
3, 3,
"torch.jit", "torch.jit",
labels_shift=2, labels_shift=2,
extra_config={"tree_implementation": "tree_trav"}, extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"},
n_estimators=10, n_estimators=10,
) )
@ -364,7 +392,7 @@ class TestSklearnTreeConverter(unittest.TestCase):
3, 3,
"torch.jit", "torch.jit",
labels_shift=2, labels_shift=2,
extra_config={"tree_implementation": "perf_tree_trav"}, extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"},
n_estimators=10, n_estimators=10,
) )
@ -375,19 +403,27 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Random forest gemm regressor # Random forest gemm regressor
def test_random_forest_ts_gemm_regressor_converter(self): def test_random_forest_ts_gemm_regressor_converter(self):
self._run_tree_regressor_converter( self._run_tree_regressor_converter(
RandomForestRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "gemm"}, n_estimators=10 RandomForestRegressor, 1000, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
) )
# Random forest tree_trav regressor # Random forest tree_trav regressor
def test_random_forest_ts_tree_trav_regressor_converter(self): def test_random_forest_ts_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter( self._run_tree_regressor_converter(
RandomForestRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "tree_trav"}, n_estimators=10 RandomForestRegressor,
1000,
"torch.jit",
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"},
n_estimators=10,
) )
# Random forest perf_tree_trav regressor # Random forest perf_tree_trav regressor
def test_random_forest_ts_perf_tree_trav_regressor_converter(self): def test_random_forest_ts_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter( self._run_tree_regressor_converter(
RandomForestRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10 RandomForestRegressor,
1000,
"torch.jit",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"},
n_estimators=10,
) )
# Extra trees regressor # Extra trees regressor
@ -397,19 +433,23 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Extra trees gemm regressor # Extra trees gemm regressor
def test_extra_trees_ts_gemm_regressor_converter(self): def test_extra_trees_ts_gemm_regressor_converter(self):
self._run_tree_regressor_converter( self._run_tree_regressor_converter(
ExtraTreesRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "gemm"}, n_estimators=10 ExtraTreesRegressor, 1000, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
) )
# Extra trees tree_trav regressor # Extra trees tree_trav regressor
def test_extra_trees_ts_tree_trav_regressor_converter(self): def test_extra_trees_ts_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter( self._run_tree_regressor_converter(
ExtraTreesRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "tree_trav"}, n_estimators=10 ExtraTreesRegressor, 1000, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}, n_estimators=10
) )
# Extra trees perf_tree_trav regressor # Extra trees perf_tree_trav regressor
def test_extra_trees_ts_perf_tree_trav_regressor_converter(self): def test_extra_trees_ts_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter( self._run_tree_regressor_converter(
ExtraTreesRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10 ExtraTreesRegressor,
1000,
"torch.jit",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"},
n_estimators=10,
) )
# Decision tree regressor # Decision tree regressor
@ -419,19 +459,19 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Decision tree gemm regressor # Decision tree gemm regressor
def test_decision_tree_ts_gemm_regressor_converter(self): def test_decision_tree_ts_gemm_regressor_converter(self):
self._run_tree_regressor_converter( self._run_tree_regressor_converter(
DecisionTreeRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "gemm"} DecisionTreeRegressor, 1000, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "gemm"}
) )
# Decision tree tree_trav regressor # Decision tree tree_trav regressor
def test_decision_tree_ts_tree_trav_regressor_converter(self): def test_decision_tree_ts_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter( self._run_tree_regressor_converter(
DecisionTreeRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "tree_trav"} DecisionTreeRegressor, 1000, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}
) )
# Decision tree perf_tree_trav regressor # Decision tree perf_tree_trav regressor
def test_decision_tree_ts_perf_tree_trav_regressor_converter(self): def test_decision_tree_ts_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter( self._run_tree_regressor_converter(
DecisionTreeRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "perf_tree_trav"} DecisionTreeRegressor, 1000, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
) )
# Decision tree classifier # Decision tree classifier
@ -444,6 +484,224 @@ class TestSklearnTreeConverter(unittest.TestCase):
def test_extra_trees_ts_classifier_converter(self): def test_extra_trees_ts_classifier_converter(self):
self._run_tree_classification_converter(ExtraTreesClassifier, 3, "torch.jit", n_estimators=10) self._run_tree_classification_converter(ExtraTreesClassifier, 3, "torch.jit", n_estimators=10)
# Test trees with TVM backend
# Random forest gemm classifier
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_gemm_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
2,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "gemm", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)
# Random forest tree_trav classifier
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_tree_trav_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
2,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)
# Random forest perf_tree_trav classifier
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_perf_tree_trav_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
2,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)
# Random forest gemm multi classifier
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_gemm_multi_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
3,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "gemm", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)
# Random forest tree_trav multi classifier
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_tree_trav_multi_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
3,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)
# Random forest perf_tree_trav multi classifier
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_perf_tree_trav_multi_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
3,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)
# Random forest gemm classifier shifted classes
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_gemm_classifier_shifted_labels_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
3,
"tvm",
labels_shift=2,
extra_config={constants.TREE_IMPLEMENTATION: "gemm", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)
# Random forest tree_trav classifier shifted classes
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_tree_trav_classifier_shifted_labels_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
3,
"tvm",
labels_shift=2,
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)
# Random forest perf_tree_trav classifier shifted classes
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_perf_tree_trav_classifier_shifted_labels_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
3,
"tvm",
labels_shift=2,
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav", constants.TVM_MAX_FUSE_DEPTH: 10},
n_estimators=10,
)
# Random forest gemm regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_gemm_regressor_converter(self):
self._run_tree_regressor_converter(
RandomForestRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "gemm", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)
# Random forest tree_trav regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
RandomForestRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)
# Random forest perf_tree_trav regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
RandomForestRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav", constants.TVM_MAX_FUSE_DEPTH: 10},
n_estimators=10,
)
# Extra trees gemm regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_extra_trees_tvm_gemm_regressor_converter(self):
self._run_tree_regressor_converter(
ExtraTreesRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "gemm", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)
# Extra trees tree_trav regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_extra_trees_tvm_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
ExtraTreesRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)
# Extra trees perf_tree_trav regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_extra_trees_tvm_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
ExtraTreesRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav", constants.TVM_MAX_FUSE_DEPTH: 10},
n_estimators=10,
)
# Decision tree regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_decision_tree_tvm_regressor_converter(self):
self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, "tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
# Decision tree gemm regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_decision_tree_tvm_gemm_regressor_converter(self):
self._run_tree_regressor_converter(
DecisionTreeRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "gemm", constants.TVM_MAX_FUSE_DEPTH: 30},
)
# Decision tree tree_trav regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_decision_tree_tvm_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
DecisionTreeRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30},
)
# Decision tree perf_tree_trav regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_decision_tree_tvm_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
DecisionTreeRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav", constants.TVM_MAX_FUSE_DEPTH: 10},
)
# Decision tree classifier
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_decision_tree_tvm_classifier_converter(self):
self._run_tree_classification_converter(
DecisionTreeClassifier, 3, "tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30}
)
# Extra trees classifier
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_extra_trees_tvm_classifier_converter(self):
self._run_tree_classification_converter(
ExtraTreesClassifier, 3, "tvm", n_estimators=10, extra_config={constants.TVM_MAX_FUSE_DEPTH: 30}
)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
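All of the TVM tests above reduce to one conversion pattern. A minimal standalone sketch of that pattern, assuming TVM is installed (the model, data, and configuration values here are illustrative, not taken from the suite); judging by their names and use in these tests, TREE_IMPLEMENTATION picks the compiled tree strategy and TVM_MAX_FUSE_DEPTH caps how deeply TVM fuses operators:

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier
    import hummingbird.ml
    from hummingbird.ml import constants

    # Train a small forest on random data (illustrative).
    np.random.seed(0)
    X = np.random.rand(100, 200).astype(np.float32)
    y = np.random.randint(2, size=100)
    model = RandomForestClassifier(n_estimators=10).fit(X, y)

    # Convert against a sample input; pick the tree strategy and cap fusion depth.
    hb_model = hummingbird.ml.convert(
        model,
        "tvm",
        X,
        extra_config={constants.TREE_IMPLEMENTATION: "gemm", constants.TVM_MAX_FUSE_DEPTH: 30},
    )
    print(hb_model.predict(X)[:5])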


@@ -9,8 +9,9 @@ import torch
from sklearn.ensemble import IsolationForest
import hummingbird.ml
from hummingbird.ml import constants
from hummingbird.ml._utils import onnx_runtime_installed, tvm_installed
from tree_utils import iforest_implementation_map
-from hummingbird.ml._utils import onnx_runtime_installed
class TestIsolationForestConverter(unittest.TestCase):
@@ -104,6 +105,23 @@ class TestIsolationForestConverter(unittest.TestCase):
np.testing.assert_allclose(model.score_samples(X), onnx_model.score_samples(X), rtol=1e-06, atol=1e-06)
np.testing.assert_array_equal(model.predict(X), onnx_model.predict(X))
# Test TVM backend.
@unittest.skipIf(not (tvm_installed()), reason="TVM test requires TVM")
def test_isolation_forest_tvm_converter(self):
warnings.filterwarnings("ignore")
for max_samples in [2 ** 1, 2 ** 3, 2 ** 8, 2 ** 10, 2 ** 12]:
model = IsolationForest(n_estimators=10, max_samples=max_samples)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
model.fit(X)
hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.decision_function(X), hb_model.decision_function(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.score_samples(X), hb_model.score_samples(X), rtol=1e-06, atol=1e-06)
np.testing.assert_array_equal(model.predict(X), hb_model.predict(X))
if __name__ == "__main__":
unittest.main()
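The three assertions in the TVM test above are tightly coupled for sklearn's IsolationForest, so together they pin the converted model down from raw scores to final labels. The relationship, shown standalone (plain sklearn semantics, no Hummingbird involved):

    import numpy as np
    from sklearn.ensemble import IsolationForest

    np.random.seed(0)
    X = np.random.rand(100, 200).astype(np.float32)
    model = IsolationForest(n_estimators=10, max_samples=64).fit(X)

    # decision_function is score_samples shifted by the fitted offset_ ...
    np.testing.assert_allclose(model.decision_function(X), model.score_samples(X) - model.offset_)
    # ... and predict is its sign: 1 for inliers, -1 for outliers.
    np.testing.assert_array_equal(model.predict(X), np.where(model.decision_function(X) < 0, -1, 1))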


@@ -10,6 +10,8 @@ from sklearn.linear_model import LinearRegression, LogisticRegression, SGDClassi
from sklearn import datasets
import hummingbird.ml
from hummingbird.ml._utils import tvm_installed
from hummingbird.ml import constants
class TestSklearnLinearClassifiers(unittest.TestCase):
@@ -224,6 +226,43 @@ class TestSklearnLinearClassifiers(unittest.TestCase):
np.testing.assert_allclose(model.predict(X), ts_model.predict(X), rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(model.predict_proba(X), ts_model.predict_proba(X), rtol=1e-6, atol=1e-6)
# Test TVM backends.
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_sgd_classifier_tvm(self):
model = SGDClassifier(loss="log")
np.random.seed(0)
num_classes = 3
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)
model.fit(X, y)
tvm_model = hummingbird.ml.convert(model, "tvm", X)
self.assertTrue(tvm_model is not None)
np.testing.assert_allclose(model.predict(X), tvm_model.predict(X), rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(model.predict_proba(X), tvm_model.predict_proba(X), rtol=1e-6, atol=1e-6)
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_lr_tvm(self):
model = LinearRegression()
np.random.seed(0)
num_classes = 1000
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)
model.fit(X, y)
tvm_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
self.assertTrue(tvm_model is not None)
np.testing.assert_allclose(model.predict(X), tvm_model.predict(X), rtol=1e-6, atol=1e-3)
if __name__ == "__main__":
unittest.main()
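One detail worth flagging in the two TVM tests above: TVM_MAX_FUSE_DEPTH is optional. test_sgd_classifier_tvm converts with no extra_config at all, while test_lr_tvm caps fusion at 30. Both call shapes, sketched with illustrative models and data:

    import numpy as np
    from sklearn.linear_model import LinearRegression, SGDClassifier
    import hummingbird.ml
    from hummingbird.ml import constants

    np.random.seed(0)
    X = np.random.rand(100, 200).astype(np.float32)
    clf = SGDClassifier(loss="log").fit(X, np.random.randint(3, size=100))
    reg = LinearRegression().fit(X, np.random.randint(1000, size=100))

    # Default fusion behavior ...
    tvm_clf = hummingbird.ml.convert(clf, "tvm", X)
    # ... versus an explicit cap on fusion depth.
    tvm_reg = hummingbird.ml.convert(reg, "tvm", X, extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})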


@@ -9,12 +9,14 @@ import torch
from sklearn.neural_network import MLPClassifier, MLPRegressor
import hummingbird.ml
from hummingbird.ml import constants
from hummingbird.ml._utils import tvm_installed
class TestSklearnMLPClassifier(unittest.TestCase):
# MLPClassifier test function to be parameterized
-def _test_mlp_classifer(self, num_classes, activation="relu", labels_shift=0):
def _test_mlp_classifer(self, num_classes, activation="relu", labels_shift=0, backend="torch", extra_config={}):
model = MLPClassifier(hidden_layer_sizes=(100, 100, 50,), activation=activation)
np.random.seed(0)
X = np.random.rand(100, 200)
@@ -22,7 +24,7 @@ class TestSklearnMLPClassifier(unittest.TestCase):
y = np.random.randint(num_classes, size=100) + labels_shift
model.fit(X, y)
-torch_model = hummingbird.ml.convert(model, "torch")
torch_model = hummingbird.ml.convert(model, backend, X, extra_config=extra_config)
self.assertTrue(torch_model is not None)
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-6, atol=1e-6)
@@ -50,6 +52,37 @@ class TestSklearnMLPClassifier(unittest.TestCase):
def test_mlp_classifer_multi_identity(self):
self._test_mlp_classifer(3, activation="identity")
# Test TVM backend
# MLPClassifier binary
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_mlp_classifer_bi_tvm(self):
self._test_mlp_classifer(2, backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
# MLPClassifier multi-class
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_mlp_classifer_multi_tvm(self):
self._test_mlp_classifer(3, backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
# MLPClassifier multi-class w/ shifted labels
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_mlp_classifer_multi_shifted_labels_tvm(self):
self._test_mlp_classifer(3, labels_shift=3, backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
# MLPClassifier multi-class w/ tanh activation
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_mlp_classifer_multi_tanh_tvm(self):
self._test_mlp_classifer(3, activation="tanh", backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
# MLPClassifier multi-class w/ logistic activation
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_mlp_classifer_multi_logistic_tvm(self):
self._test_mlp_classifer(3, activation="logistic", backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
# MLPClassifier multi-class w/ identity activation
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_mlp_classifer_multi_identity_tvm(self):
self._test_mlp_classifer(3, activation="identity", backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
# MLPRegressor test function to be parameterized
def _test_mlp_regressor(self, activation="relu"):
model = MLPRegressor(hidden_layer_sizes=(100, 100, 50,), activation=activation)


@@ -9,13 +9,14 @@ import torch
from sklearn.naive_bayes import BernoulliNB, GaussianNB, MultinomialNB
import hummingbird.ml
from hummingbird.ml._utils import tvm_installed
class TestSklearnNBClassifier(unittest.TestCase):
# BernoulliNB test function to be parameterized
def _test_bernoulinb_classifer(
-self, num_classes, alpha=1.0, binarize=None, fit_prior=False, class_prior=None, labels_shift=0
self, num_classes, alpha=1.0, binarize=None, fit_prior=False, class_prior=None, labels_shift=0, backend="torch"
):
model = BernoulliNB(alpha=alpha, binarize=binarize, fit_prior=fit_prior, class_prior=class_prior)
np.random.seed(0)
@@ -27,7 +28,7 @@ class TestSklearnNBClassifier(unittest.TestCase):
y = np.random.randint(num_classes, size=100) + labels_shift
model.fit(X, y)
-torch_model = hummingbird.ml.convert(model, "torch")
torch_model = hummingbird.ml.convert(model, backend, X)
self.assertTrue(torch_model is not None)
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-6, atol=1e-5)
@@ -61,8 +62,48 @@ class TestSklearnNBClassifier(unittest.TestCase):
def test_bernoulinb_classifer_multi_labels_shift(self):
self._test_bernoulinb_classifer(3, labels_shift=3)
# Test TVM backend
# BernoulliNB binary
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_bernoulinb_classifer_bi_tvm(self):
self._test_bernoulinb_classifer(2, backend="tvm")
# BernoulliNB multi-class
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_bernoulinb_classifer_multi_tvm(self):
self._test_bernoulinb_classifer(3, backend="tvm")
# BernoulliNB multi-class w/ modified alpha
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_bernoulinb_classifer_multi_alpha_tvm(self):
self._test_bernoulinb_classifer(3, alpha=0.5, backend="tvm")
# BernoulliNB multi-class w/ binarize
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_bernoulinb_classifer_multi_binarize_tvm(self):
self._test_bernoulinb_classifer(3, binarize=0.5, backend="tvm")
# BernoulliNB multi-class w/ fit prior
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_bernoulinb_classifer_multi_fit_prior_tvm(self):
self._test_bernoulinb_classifer(3, fit_prior=True, backend="tvm")
# BernoulliNB multi-class w/ class prior
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_bernoulinb_classifer_multi_class_prior_tvm(self):
np.random.seed(0)
class_prior = np.random.rand(3)
self._test_bernoulinb_classifer(3, class_prior=class_prior, backend="tvm")
# BernoulliNB multi-class w/ labels shift
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_bernoulinb_classifer_multi_labels_shift_tvm(self):
self._test_bernoulinb_classifer(3, labels_shift=3, backend="tvm")
# MultinomialNB test function to be parameterized
-def _test_multinomialnb_classifer(self, num_classes, alpha=1.0, fit_prior=False, class_prior=None, labels_shift=0):
def _test_multinomialnb_classifer(
self, num_classes, alpha=1.0, fit_prior=False, class_prior=None, labels_shift=0, backend="torch"
):
model = MultinomialNB(alpha=alpha, fit_prior=fit_prior, class_prior=class_prior)
np.random.seed(0)
X = np.random.rand(100, 200)
@@ -70,7 +111,7 @@ class TestSklearnNBClassifier(unittest.TestCase):
y = np.random.randint(num_classes, size=100) + labels_shift
model.fit(X, y)
-torch_model = hummingbird.ml.convert(model, "torch")
torch_model = hummingbird.ml.convert(model, backend, X)
self.assertTrue(torch_model is not None)
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-6, atol=1e-5)
@@ -100,8 +141,41 @@ class TestSklearnNBClassifier(unittest.TestCase):
def test_multinomialnb_classifer_multi_labels_shift(self):
self._test_multinomialnb_classifer(3, labels_shift=3)
# TVM Backend
# MultinomialNB binary
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_multinomialnb_classifer_bi_tvm(self):
self._test_multinomialnb_classifer(2, backend="tvm")
# MultinomialNB multi-class
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_multinomialnb_classifer_multi_tvm(self):
self._test_multinomialnb_classifer(3, backend="tvm")
# MultinomialNB multi-class w/ modified alpha
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_multinomialnb_classifer_multi_alpha_tvm(self):
self._test_multinomialnb_classifer(3, alpha=0.5, backend="tvm")
# MultinomialNB multi-class w/ fit prior
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_multinomialnb_classifer_multi_fit_prior_tvm(self):
self._test_multinomialnb_classifer(3, fit_prior=True, backend="tvm")
# MultinomialNB multi-class w/ class prior
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_multinomialnb_classifer_multi_class_prior_tvm(self):
np.random.seed(0)
class_prior = np.random.rand(3)
self._test_multinomialnb_classifer(3, class_prior=class_prior, backend="tvm")
# MultinomialNB multi-class w/ labels shift
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_multinomialnb_classifer_multi_labels_shift_tvm(self):
self._test_multinomialnb_classifer(3, labels_shift=3, backend="tvm")
# GaussianNB test function to be parameterized
-def _test_gaussiannb_classifer(self, num_classes, priors=None, var_smoothing=1e-9, labels_shift=0):
def _test_gaussiannb_classifer(self, num_classes, priors=None, var_smoothing=1e-9, labels_shift=0, backend="torch"):
model = GaussianNB(priors=priors, var_smoothing=var_smoothing)
np.random.seed(0)
X = np.random.rand(100, 200)
@@ -109,9 +183,9 @@ class TestSklearnNBClassifier(unittest.TestCase):
y = np.random.randint(num_classes, size=100) + labels_shift
model.fit(X, y)
-torch_model = hummingbird.ml.convert(model, "torch")
torch_model = hummingbird.ml.convert(model, backend, X)
self.assertTrue(torch_model is not None)
-np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-6, atol=1e-5)
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-5, atol=1e-5)
# GaussianNB binary
def test_gaussiannb_classifer_bi(self):
@@ -136,6 +210,35 @@ class TestSklearnNBClassifier(unittest.TestCase):
def test_gaussiannb_classifer_multi_labels_shift(self):
self._test_gaussiannb_classifer(3, labels_shift=3)
# TVM Backend
# GaussianNB binary
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_gaussiannb_classifer_bi_tvm(self):
self._test_gaussiannb_classifer(2, backend="tvm")
# GaussianNB multi-class
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_gaussiannb_classifer_multi_tvm(self):
self._test_gaussiannb_classifer(3, backend="tvm")
# GaussianNB multi-class w/ class prior
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_gaussiannb_classifer_multi_class_prior_tvm(self):
np.random.seed(0)
priors = np.random.rand(3)
priors = priors / np.sum(priors)
self._test_gaussiannb_classifer(3, priors=priors, backend="tvm")
# GaussianNB multi-class w/ modified var_smoothing
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_gaussiannb_classifer_multi_var_smoothing_tvm(self):
self._test_gaussiannb_classifer(3, var_smoothing=1e-2, backend="tvm")
# GaussianNB multi-class w/ labels shift
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_gaussiannb_classifer_multi_labels_shift_tvm(self):
self._test_gaussiannb_classifer(3, labels_shift=3, backend="tvm")
if __name__ == "__main__":
unittest.main()
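A side note on test_gaussiannb_classifer_multi_class_prior_tvm above: the priors are normalized before being handed to GaussianNB because sklearn rejects priors that do not sum to 1; that is a sklearn constraint, not a Hummingbird one. Standalone:

    import numpy as np

    np.random.seed(0)
    priors = np.random.rand(3)
    priors = priors / np.sum(priors)  # GaussianNB raises ValueError without this
    assert np.isclose(priors.sum(), 1.0)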


@@ -9,6 +9,8 @@ import torch
from sklearn.preprocessing import Normalizer
import hummingbird.ml
from hummingbird.ml import constants
from hummingbird.ml._utils import onnx_runtime_installed, tvm_installed
class TestSklearnNormalizer(unittest.TestCase):
@@ -61,6 +63,46 @@ class TestSklearnNormalizer(unittest.TestCase):
model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06,
)
# ONNX backend
@unittest.skipIf(not (onnx_runtime_installed()), reason="ONNXML test requires ONNX and ORT")
def test_normalizer_converter_onnx(self):
# Generate a random 2D array with values in [0, 1000)
np.random.seed(0)
data = np.random.rand(100, 200) * 1000
data = np.array(data, dtype=np.float32)
data_tensor = torch.from_numpy(data)
for norm in ["l1", "l2", "max"]:
model = Normalizer(norm=norm)
model.fit(data)
hb_model = hummingbird.ml.convert(model, "onnx", data)
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(
model.transform(data), hb_model.transform(data_tensor)[0], rtol=1e-06, atol=1e-06,
)
# TVM backend
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_normalizer_converter_tvm(self):
# Generate a random 2D array with values in [0, 1000)
np.random.seed(0)
data = np.random.rand(100, 200) * 1000
data = np.array(data, dtype=np.float32)
data_tensor = torch.from_numpy(data)
for norm in ["l1", "l2", "max"]:
model = Normalizer(norm=norm)
model.fit(data)
torch_model = hummingbird.ml.convert(model, "tvm", data, extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
self.assertIsNotNone(torch_model)
np.testing.assert_allclose(
model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06,
)
if __name__ == "__main__":
unittest.main()


@@ -7,39 +7,57 @@ import torch
from sklearn.preprocessing import RobustScaler, MaxAbsScaler, MinMaxScaler, StandardScaler
import hummingbird.ml
from hummingbird.ml._utils import tvm_installed
from hummingbird.ml import constants
class TestSklearnScalerConverter(unittest.TestCase):
-def test_robust_scaler_floats(self):
def _test_robust_scaler_floats(self, with_centering, with_scaling, backend="torch"):
# Generate a random 2D array with values in [0, 1000)
np.random.seed(0)
data = np.random.rand(100, 200) * 1000
data = np.array(data, dtype=np.float32)
data_tensor = torch.from_numpy(data)
-model = RobustScaler(with_centering=False, with_scaling=False)
model = RobustScaler(with_centering=with_centering, with_scaling=with_scaling)
model.fit(data)
-torch_model = hummingbird.ml.convert(model, "torch")
torch_model = hummingbird.ml.convert(model, backend, data)
self.assertIsNotNone(torch_model)
np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)
-model = RobustScaler(with_centering=False, with_scaling=True)
-model.fit(data)
-torch_model = hummingbird.ml.convert(model, "torch")
-self.assertIsNotNone(torch_model)
-np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)
-model = RobustScaler(with_centering=True, with_scaling=False)
-model.fit(data)
-torch_model = hummingbird.ml.convert(model, "torch")
-self.assertIsNotNone(torch_model)
-np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)
-model = RobustScaler(with_centering=True, with_scaling=True)
-model.fit(data)
-torch_model = hummingbird.ml.convert(model, "torch")
-self.assertIsNotNone(torch_model)
-np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)
def _test_standard_scaler_floats(self, with_mean, with_std, backend="torch"):
# Generate a random 2D array with values in [0, 1000)
np.random.seed(0)
data = np.random.rand(100, 200) * 1000
data = np.array(data, dtype=np.float32)
data_tensor = torch.from_numpy(data)
model = StandardScaler(with_mean=with_mean, with_std=with_std)
model.fit(data)
torch_model = hummingbird.ml.convert(model, backend, data)
self.assertIsNotNone(torch_model)
np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)
def test_robust_scaler_floats_torch_false_false(self):
self._test_robust_scaler_floats(False, False)
def test_robust_scaler_floats_torch_true_false(self):
self._test_robust_scaler_floats(True, False)
def test_robust_scaler_floats_torch_false_true(self):
self._test_robust_scaler_floats(False, True)
def test_robust_scaler_floats_torch_true_true(self):
self._test_robust_scaler_floats(True, True)
def test_standard_scaler_floats_torch_false_false(self):
self._test_standard_scaler_floats(False, False)
def test_standard_scaler_floats_torch_true_false(self):
self._test_standard_scaler_floats(True, False)
def test_standard_scaler_floats_torch_true_true(self):
self._test_standard_scaler_floats(True, True)
def test_max_abs_scaler_floats(self):
# Generate a random 2D array with values in [0, 1000)
@@ -69,31 +87,6 @@ class TestSklearnScalerConverter(unittest.TestCase):
self.assertIsNotNone(torch_model)
np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)
-def test_standard_scaler_floats(self):
-# Generate a random 2D array with values in [0, 1000)
-np.random.seed(0)
-data = np.random.rand(100, 200) * 1000
-data = np.array(data, dtype=np.float32)
-data_tensor = torch.from_numpy(data)
-model = StandardScaler(with_mean=False, with_std=False)
-model.fit(data)
-torch_model = hummingbird.ml.convert(model, "torch")
-self.assertIsNotNone(torch_model)
-np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)
-model = StandardScaler(with_mean=True, with_std=False)
-model.fit(data)
-torch_model = hummingbird.ml.convert(model, "torch")
-self.assertIsNotNone(torch_model)
-np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-04, atol=1e-04)
-model = StandardScaler(with_mean=True, with_std=True)
-model.fit(data)
-torch_model = hummingbird.ml.convert(model, "torch")
-self.assertIsNotNone(torch_model)
-np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)
# Float 64 data tests
def test_float64_robust_scaler_floats(self):
# Generate a random 2D array with values in [0, 1000)
@@ -107,6 +100,35 @@ class TestSklearnScalerConverter(unittest.TestCase):
self.assertIsNotNone(torch_model)
np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)
# Tests with TVM backend
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
def test_standard_scaler_floats_tvm_false_false(self):
self._test_standard_scaler_floats(False, False, "tvm")
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
def test_standard_scaler_floats_tvm_true_false(self):
self._test_standard_scaler_floats(True, False, "tvm")
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
def test_standard_scaler_floats_tvm_true_true(self):
self._test_standard_scaler_floats(True, True, "tvm")
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
def test_robust_scaler_floats_tvm_false_false(self):
self._test_robust_scaler_floats(False, False, "tvm")
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
def test_robust_scaler_floats_tvm_true_false(self):
self._test_robust_scaler_floats(True, False, "tvm")
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
def test_robust_scaler_floats_tvm_false_true(self):
self._test_robust_scaler_floats(False, True, "tvm")
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
def test_robust_scaler_floats_tvm_true_true(self):
self._test_robust_scaler_floats(True, True, "tvm")
if __name__ == "__main__":
unittest.main()
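The refactoring above folds four near-identical copies of the scaler check into two parameterized helpers, so each torch or TVM case becomes a one-line test. What every case boils down to, as a standalone sketch (assuming TVM is installed; the data is illustrative):

    import numpy as np
    import torch
    from sklearn.preprocessing import StandardScaler
    import hummingbird.ml

    # Generate a random 2D array with values in [0, 1000).
    np.random.seed(0)
    data = (np.random.rand(100, 200) * 1000).astype(np.float32)
    model = StandardScaler(with_mean=True, with_std=True).fit(data)

    # Convert against the sample input, then compare transforms.
    hb_model = hummingbird.ml.convert(model, "tvm", data)
    np.testing.assert_allclose(
        model.transform(data), hb_model.transform(torch.from_numpy(data)), rtol=1e-06, atol=1e-06
    )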


@@ -8,6 +8,8 @@ import torch
from sklearn.svm import LinearSVC, SVC, NuSVC
import hummingbird.ml
from hummingbird.ml import constants
from hummingbird.ml._utils import tvm_installed
class TestSklearnSVC(unittest.TestCase):
@@ -36,7 +38,7 @@ class TestSklearnSVC(unittest.TestCase):
self._test_linear_svc(3, labels_shift=2)
# SVC test function to be parameterized
-def _test_svc(self, num_classes, kernel="rbf", gamma=None, labels_shift=0):
def _test_svc(self, num_classes, kernel="rbf", gamma=None, backend="torch", labels_shift=0, extra_config={}):
if gamma:
model = SVC(kernel=kernel, gamma=gamma)
@@ -48,7 +50,7 @@ class TestSklearnSVC(unittest.TestCase):
y = np.random.randint(num_classes, size=100) + labels_shift
model.fit(X, y)
-torch_model = hummingbird.ml.convert(model, "torch")
torch_model = hummingbird.ml.convert(model, backend, X, extra_config=extra_config)
self.assertTrue(torch_model is not None)
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-6, atol=1e-6)
@@ -82,7 +84,7 @@ class TestSklearnSVC(unittest.TestCase):
self._test_svc(3, gamma="auto")
# NuSVC test function to be parameterized
-def _test_nu_svc(self, num_classes):
def _test_nu_svc(self, num_classes, backend="torch", extra_config={}):
model = NuSVC()
np.random.seed(0)
X = np.random.rand(100, 200)
@@ -90,7 +92,7 @@ class TestSklearnSVC(unittest.TestCase):
y = np.random.randint(num_classes, size=100)
model.fit(X, y)
-torch_model = hummingbird.ml.convert(model, "torch")
torch_model = hummingbird.ml.convert(model, backend, X, extra_config=extra_config)
self.assertTrue(torch_model is not None)
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-6, atol=1e-6)
@@ -126,6 +128,77 @@ class TestSklearnSVC(unittest.TestCase):
self.assertTrue(torch_model is not None)
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-6, atol=1e-6)
# Torchscript backend
def test_svc_ts(self):
self._test_svc(2, backend="torch.jit")
# SVC linear kernel
def test_svc_linear_ts(self):
self._test_svc(2, kernel="linear", backend="torch.jit")
# SVC sigmoid kernel
def test_svc_sigmoid_ts(self):
self._test_svc(2, kernel="sigmoid", backend="torch.jit")
# SVC poly kernel
def test_svc_poly_ts(self):
self._test_svc(2, kernel="poly", backend="torch.jit")
# NuSVC binary
def test_nu_svc_bi_ts(self):
self._test_nu_svc(2, backend="torch.jit")
def test_svc_multi_ts(self):
self._test_svc(3, backend="torch.jit")
# TVM backend.
# SVC binary
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_svc_tvm(self):
self._test_svc(2, backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
# SVC linear kernel
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_svc_linear_tvm(self):
self._test_svc(2, kernel="linear", backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
# SVC sigmoid kernel
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_svc_sigmoid_tvm(self):
self._test_svc(2, kernel="sigmoid", backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
# SVC poly kernel
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_svc_poly_tvm(self):
self._test_svc(2, kernel="poly", backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
# NuSVC binary
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_nu_svc_bi_tvm(self):
self._test_nu_svc(2, backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
# Commenting SVC multiclass for TVM because we are currently missing an operator for implementing mode
# https://github.com/pytorch/pytorch/issues/43867.
# # SVC multiclass
# def test_svc_multi(self):
# self._test_svc(3, backend="tvm")
# # SVC sigmoid kernel
# def test_svc_sigmoid(self):
# self._test_svc(3, kernel="sigmoid", backend="tvm")
# # SVC poly kernel
# def test_svc_poly(self):
# self._test_svc(3, kernel="poly", backend="tvm")
# # SVC with class labels shifted
# def test_svc_shifted(self):
# self._test_svc(3, labels_shift=2, backend="tvm")
# # SVC with different gamma (default=scale)
# def test_svc_gamma(self):
# self._test_svc(3, gamma="auto", backend="tvm")
if __name__ == "__main__":
unittest.main()
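For context on the commented-out multiclass block above: multiclass SVC combines its one-vs-one decisions by majority vote, i.e. a mode reduction over the pairwise predictions, and exporting that operator is what the linked PyTorch issue tracks. A standalone illustration of the computation (the vote tensor is made up):

    import torch

    # votes[i, j] = class predicted for sample i by pairwise classifier j.
    votes = torch.tensor([[0, 2, 2], [1, 1, 0]])

    # torch.mode reduces each row to its most frequent class.
    predicted, _ = torch.mode(votes, dim=1)
    print(predicted)  # tensor([2, 1])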


@@ -7,7 +7,8 @@ import warnings
import numpy as np
import hummingbird.ml
-from hummingbird.ml._utils import xgboost_installed
from hummingbird.ml._utils import xgboost_installed, tvm_installed
from hummingbird.ml import constants
from tree_utils import gbdt_implementation_map
if xgboost_installed():
@@ -102,7 +103,7 @@ class TestXGBoostConverter(unittest.TestCase):
model.fit(X, y, group=[X.shape[0]])
-torch_model = hummingbird.ml.convert(model, "torch", X[0:1], extra_config=extra_config)
torch_model = hummingbird.ml.convert(model, "torch", X, extra_config=extra_config)
self.assertIsNotNone(torch_model)
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)
@@ -136,7 +137,7 @@ class TestXGBoostConverter(unittest.TestCase):
y = np.random.randint(num_classes, size=100)
model.fit(X, y)
-torch_model = hummingbird.ml.convert(model, "torch", X[0:1], extra_config=extra_config)
torch_model = hummingbird.ml.convert(model, "torch", X, extra_config=extra_config)
self.assertIsNotNone(torch_model)
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)
@@ -225,6 +226,7 @@ class TestXGBoostConverter(unittest.TestCase):
self.assertIsNotNone(torch_model)
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
# Torchscript backends.
# Test TorchScript backend regression.
@unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires XGBoost installed")
def test_xgb_regressor_converter_torchscript(self):
@@ -263,6 +265,47 @@ class TestXGBoostConverter(unittest.TestCase):
self.assertIsNotNone(torch_model)
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
# TVM backend tests.
# TVM backend regression.
@unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires XGBoost installed")
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
def test_xgb_regressor_converter_tvm(self):
warnings.filterwarnings("ignore")
import torch
for max_depth in [1, 3, 8, 10, 12]:
model = xgb.XGBRegressor(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(1000, size=100)
model.fit(X, y)
tvm_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
self.assertIsNotNone(tvm_model)
np.testing.assert_allclose(model.predict(X), tvm_model.predict(X), rtol=1e-06, atol=1e-06)
# Test TVM backend classification.
@unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires XGBoost installed")
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
def test_xgb_classifier_converter_tvm(self):
warnings.filterwarnings("ignore")
import torch
for max_depth in [1, 3, 8, 10, 12]:
model = xgb.XGBClassifier(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(2, size=100)
model.fit(X, y)
tvm_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
self.assertIsNotNone(tvm_model)
np.testing.assert_allclose(model.predict_proba(X), tvm_model.predict_proba(X), rtol=1e-06, atol=1e-06)
if __name__ == "__main__":
unittest.main()
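A closing note on the X[0:1] -> X change in the torch call sites above: convert() takes a sample input, and these tests then predict over the full X. Passing the whole array keeps the sample's shape consistent with the data the assertions use; reading the change as targeting batch-size-sensitive code paths is an inference, not something the diff states. The resulting call shape, with an illustrative model:

    import numpy as np
    import xgboost as xgb
    import hummingbird.ml

    np.random.seed(0)
    X = np.random.rand(100, 200).astype(np.float32)
    y = np.random.randint(2, size=100)
    model = xgb.XGBClassifier(n_estimators=10).fit(X, y)

    # The sample input now matches the arrays used at prediction time.
    hb_model = hummingbird.ml.convert(model, "torch", X)
    np.testing.assert_allclose(model.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)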