Add TVM backend (#236)
* Add containers for ONNX models
* Add tvm_installed, initial work on topology
* Add containers, add TVM backend to supported backends, add a few tests
* Fix type error in TVM; tree_trav and perf_tree_trav now work
* Add TVM_MAX_FUSE_DEPTH option; add BATCH_SIZE option; tree_trav generates indexes based on batch size (if available); TVM takes the max fuse depth configuration if set
This commit is contained in:
Parent
db65391556
Commit
409c09a937
@@ -27,32 +27,29 @@ jobs:
       uses: actions/setup-python@v2
       with:
         python-version: ${{ matrix.python-version }}
+    # PyTorch for Mac has different pip syntax wrt Win and Linux.
+    # PyTorch stopped supporting Python 3.5 from version 1.6.
+    # The following cases address the situations above.
     - name: Install pytorch 1.5.1 if python 3.5 (mac)
       if: ${{ matrix.python-version == '3.5' && matrix.os == 'macos-latest' }}
-      run: |
-        pip install torch==1.5.1
+      run: pip install torch==1.5.1
     - name: Install pytorch 1.7.0 if python > 3.5 (mac)
       if: ${{ matrix.python-version != '3.5' && matrix.os == 'macos-latest' }}
-      run: |
-        pip install torch==1.7.0
+      run: pip install torch==1.7.0
     - name: Install pytorch 1.5.1+cpu if python 3.5 (not mac)
       if: ${{ matrix.python-version == '3.5' && matrix.os != 'macos-latest' }}
-      run: |
-        pip install torch==1.5.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
+      run: pip install torch==1.5.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
     - name: Install pytorch 1.7.0+cpu if python > 3.5 (not mac)
       if: ${{ matrix.python-version != '3.5' && matrix.os != 'macos-latest' }}
-      run: |
-        pip install torch==1.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+      run: pip install torch==1.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
     - name: Install basic dependencies
       run: |
        python -m pip install --upgrade pip
        pip install .[tests] -f https://download.pytorch.org/whl/torch_stable.html
     - name: Run basic tests without extra
-      run: |
-        pytest
+      run: pytest
     - name: Coverage on basic tests without extra
-      run: |
-        coverage run -a -m pytest tests/test_no_extra_install.py
+      run: coverage run -a -m pytest tests/test_no_extra_install.py
     - name: If mac, install libomp to facilitate lgbm install
       if: matrix.os == 'macOS-latest'
       run: |
@@ -67,17 +64,70 @@ jobs:
       run: |
        pip install .[extra,onnx,sparkml]
        pip install pandas
+    - uses: actions/cache@v1
+      # TVM takes forever, we try to cache it.
+      if: ${{ matrix.python-version != '3.5' && matrix.os != 'windows-2019' }}
+      id: cache
+      env:
+        CACHE_NUMBER: 1
+      with:
+        path: ../../../incubator-tvm
+        key: ${{ runner.os }}-${{ env.CACHE_NUMBER }}-tvm-0.7
+    # Getting TVM requires: 1) fetching TVM from github, 2) getting LLVM, 3) cmake, 4) make, 5) installing the python dependency.
+    # 1 to 4 will be retrieved from the cache.
+    # The pipeline only works for Unix systems. For Windows we would have to compile LLVM from source, which is a no-go.
+    - name: Fetch and prepare TVM for compilation
+      if: ${{ steps.cache.outputs.cache-hit != 'true' && matrix.python-version != '3.5' && matrix.os != 'windows-2019' }}
+      run: |
+        cd ~/
+        git clone https://github.com/apache/incubator-tvm.git
+        cd incubator-tvm
+        git checkout tags/v0.7.0
+        git submodule update --recursive --init
+        cmake -E make_directory build
+    - name: Get LLVM on Linux
+      if: ${{ steps.cache.outputs.cache-hit != 'true' && matrix.python-version != '3.5' && matrix.os == 'ubuntu-latest' }}
+      working-directory: ../../../incubator-tvm
+      run: |
+        wget https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/clang+llvm-10.0.0-x86_64-linux-gnu-ubuntu-18.04.tar.xz
+        tar -xf clang+llvm-10.0.0-x86_64-linux-gnu-ubuntu-18.04.tar.xz && mv clang+llvm-10.0.0-x86_64-linux-gnu-ubuntu-18.04 llvm
+    - name: Get LLVM on Mac
+      if: ${{ steps.cache.outputs.cache-hit != 'true' && matrix.python-version != '3.5' && matrix.os == 'macos-latest' }}
+      working-directory: ../../../incubator-tvm
+      run: |
+        wget https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/clang+llvm-10.0.0-x86_64-apple-darwin.tar.xz
+        tar -xf clang+llvm-10.0.0-x86_64-apple-darwin.tar.xz && mv clang+llvm-10.0.0-x86_64-apple-darwin llvm
+    - name: CMake TVM
+      if: ${{ steps.cache.outputs.cache-hit != 'true' && matrix.python-version != '3.5' && matrix.os != 'windows-2019' }}
+      working-directory: ../../../incubator-tvm/build
+      run: >-
+        cmake
+        "-DUSE_RPC=ON"
+        "-DUSE_GRAPH_RUNTIME=ON"
+        "-DUSE_LLVM=../llvm/bin/llvm-config"
+        ..
+    - name: Build TVM
+      if: ${{ steps.cache.outputs.cache-hit != 'true' && matrix.python-version != '3.5' && matrix.os != 'windows-2019' }}
+      working-directory: ../../../incubator-tvm/build
+      run: |
+        make -j3
+    - name: Install python TVM
+      if: ${{ matrix.python-version != '3.5' && matrix.os != 'windows-2019' }}
+      working-directory: ../../../incubator-tvm/python
+      run: |
+        python setup.py install
     - name: Lint with flake8
       run: |
        # stop the build if there are Python syntax errors or undefined names
        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
        # The GitHub editor is 127 chars wide
        flake8 . --count --max-complexity=10 --max-line-length=127 --statistics
+    # We don't run pytest for Linux py3.7 since we do coverage for that case.
     - name: Test with pytest
-      run: |
-        pytest
+      if: ${{ matrix.python-version == '3.7' && matrix.os != 'ubuntu-latest' }}
+      run: pytest
+    # Run and push coverage only for one of the runs (Linux py3.7).
     - name: Coverage
-      # run and push coverage only on one of the runs
       if: ${{ matrix.python-version == '3.7' && matrix.os == 'ubuntu-latest' }}
       run: |
        coverage run -a -m pytest tests
@@ -88,7 +138,8 @@ jobs:
       with:
        file: ./coverage.xml
        flags: unittests
-    - name: Generate Documentation # for some awful reason, this only works with older torch????
+    # Compile and push documentation only for one of the runs (Linux py3.7).
+    - name: Generate Documentation
+      if: ${{ matrix.python-version == '3.7' && matrix.os == 'ubuntu-latest' }}
       run: |
        make sphinx-site -C website/
@@ -16,8 +16,8 @@
## Introduction
*Hummingbird* is a library for compiling trained traditional ML models into tensor computations. *Hummingbird* allows users to seamlessly leverage neural network frameworks (such as [PyTorch](https://pytorch.org/)) to accelerate traditional ML models. Thanks to *Hummingbird*, users can benefit from: (1) all the current and future optimizations implemented in neural network frameworks; (2) native hardware acceleration; (3) having a unique platform to support both traditional and neural network models; and have all of this (4) without having to re-engineer their models.

-Currently, you can use *Hummingbird* to convert your trained traditional ML models into [PyTorch](https://pytorch.org/), [TorchScript](https://pytorch.org/docs/stable/jit.html), and [ONNX](https://onnx.ai/). *Hummingbird* [supports](https://github.com/microsoft/hummingbird/wiki/Supported-Operators) a variety of ML models and featurizers. These models include
-[scikit-learn](https://scikit-learn.org/stable/) Decision Trees and Random Forest, and also [LightGBM](https://github.com/Microsoft/LightGBM) and [XGBoost](https://github.com/dmlc/xgboost) Classifiers/Regressors. Support for other neural network backends (e.g., [TVM](https://docs.tvm.ai/)) and models is on our [roadmap](https://github.com/microsoft/hummingbird/wiki/Roadmap-for-Upcoming-Features-and-Support).
+Currently, you can use *Hummingbird* to convert your trained traditional ML models into [PyTorch](https://pytorch.org/), [TorchScript](https://pytorch.org/docs/stable/jit.html), [ONNX](https://onnx.ai/), and [TVM](https://docs.tvm.ai/). *Hummingbird* [supports](https://github.com/microsoft/hummingbird/wiki/Supported-Operators) a variety of ML models and featurizers. These models include
+[scikit-learn](https://scikit-learn.org/stable/) Decision Trees and Random Forest, and also [LightGBM](https://github.com/Microsoft/LightGBM) and [XGBoost](https://github.com/dmlc/xgboost) Classifiers/Regressors. Support for other neural network backends and models is on our [roadmap](https://github.com/microsoft/hummingbird/wiki/Roadmap-for-Upcoming-Features-and-Support).

Hummingbird also provides a convenient uniform "inference" API following the Sklearn API. This allows swapping Sklearn models with Hummingbird-generated ones without having to change the inference code.
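A minimal end-to-end sketch of the conversion path this commit adds (hypothetical data and model; the "tvm" backend name and the test-input requirement come from this commit):

    # Hedged sketch: convert a trained sklearn model to the new TVM backend.
    import numpy as np
    from sklearn.ensemble import RandomForestClassifier
    from hummingbird.ml import convert

    X = np.random.rand(100000, 28).astype(np.float32)
    y = np.random.randint(2, size=100000)
    skl_model = RandomForestClassifier(n_estimators=10, max_depth=10).fit(X, y)

    # The tvm backend needs test input for tracing and shape specialization.
    hb_model = convert(skl_model, "tvm", X)
    print(hb_model.predict(X[0:100]))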
@@ -63,11 +63,7 @@ import benchmarks.operators.train as train
import benchmarks.operators.score as score
from benchmarks.datasets import prepare_dataset, LearningTask

-from hummingbird.ml._utils import (
-    sklearn_installed,
-    onnx_ml_tools_installed,
-    onnx_runtime_installed,
-)
+from hummingbird.ml._utils import sklearn_installed, onnx_ml_tools_installed, onnx_runtime_installed, tvm_installed

ROOT_PATH = Path(__file__).absolute().parent.parent.parent
@@ -106,7 +102,6 @@ def get_number_processors(args):


def print_sys_info(args):
-    import onnxruntime
    import sklearn
    import torch
@@ -114,7 +109,20 @@ def print_sys_info(args):
    print("OS : %s" % sys.platform)
    print("Sklearn : %s" % sklearn.__version__)
    print("PyTorch : %s" % torch.__version__)

+    # Optional imports
+    try:
+        import onnxruntime
+
+        print("ORT : %s" % onnxruntime.__version__)
+    except ImportError:
+        pass
+    try:
+        import tvm
+
+        print("TVM : %s" % tvm.__version__)
+    except ImportError:
+        pass

    if args.gpu:
        print("Running on GPU")
@@ -229,7 +237,19 @@ def benchmark(args, dataset_folder, model_folder, dataset):
    args.operator = op

    if args.backend == "all":
-        args.backend = "onnx-ml,hb-torchscript,hb-onnx"
+        args.backend = "onnx-ml,hb-pytorch,hb-torchscript,hb-onnx"
+    if "hb-tvm" in args.backend:
+        assert (
+            tvm_installed
+        ), "To run benchmark with TVM you need to have TVM installed. Either install TVM or remove it from the backends."
+    if "hb-onnx" in args.backend:
+        assert (
+            onnx_runtime_installed
+        ), "To run benchmark with ONNX you need to have ONNX runtime installed. Either install ONNX runtime or remove ONNX from the backends."
+    if "onnx-ml" in args.backend:
+        assert (
+            onnx_runtime_installed and onnx_ml_tools_installed
+        ), "To run benchmark with ONNX-ML you need to have ONNX runtime and ONNXMLTOOLS installed. Either install ONNX runtime and ONNXMLTOOLS or remove ONNX-ML from the backends."
    for backend in args.backend.split(","):
        print("Running '%s' ..." % backend)
        scorer = score.ScoreBackend.create(backend)
@@ -44,6 +44,8 @@ class ScoreBackend(ABC):
            return HBBackend("torch")
        if name == "hb-torchscript":
            return HBBackend("torch.jit")
+        if name == "hb-tvm":
+            return HBBackend("tvm")
        if name == "hb-onnx":
            return HBBackend("onnx")
        if name == "onnx-ml":
@@ -118,7 +120,7 @@ class ScoreBackend(ABC):

class HBBackend(ScoreBackend):
    def __init__(self, backend):
-        super().__init__()
+        super(HBBackend, self).__init__()
        self.backend = backend

    def convert(self, model, data, args, model_name):
@@ -32,6 +32,21 @@ def print_sys_info(args):
    print("OS : %s" % sys.platform)
    print("Sklearn: %s" % sklearn.__version__)
    print("Torch: %s" % torch.__version__)

+    # Optional imports
+    try:
+        import onnxruntime
+
+        print("ORT : %s" % onnxruntime.__version__)
+    except ImportError:
+        pass
+    try:
+        import tvm
+
+        print("TVM : %s" % tvm.__version__)
+    except ImportError:
+        pass

    if args.gpu:
        print("Running on GPU")
    else:
@@ -15,7 +15,7 @@ import hummingbird.ml
class ScoreBackend(ABC):
    @staticmethod
    def create(name):
-        if name in ["torch", "torch.jit", "onnx"]:
+        if name in ["torch", "torch.jit", "tvm", "onnx"]:
            return HBBackend(name)
        raise ValueError("Unknown backend: " + name)
@@ -56,6 +56,7 @@ from hummingbird.ml._utils import (
    sklearn_installed,
    onnx_ml_tools_installed,
    onnx_runtime_installed,
+    tvm_installed,
)
@@ -88,12 +89,19 @@ def print_sys_info(args):
    print("Sklearn : %s" % sklearn.__version__)
    print("PyTorch : %s" % torch.__version__)

+    # Optional imports
+    try:
+        import onnxruntime
+
+        print("ORT : %s" % onnxruntime.__version__)
+    except ImportError:
+        pass
+    try:
+        import tvm
+
+        print("TVM : %s" % tvm.__version__)
+    except ImportError:
+        pass

    if args.gpu:
        print("Running on GPU")
@@ -235,7 +243,19 @@ def benchmark(args, dataset_folder, model_folder, dataset):
    args.operator = op

    if args.backend == "all":
-        args.backend = "onnx-ml,hb-pytorch,hb-torchscript,hb-onnx"
+        args.backend = "onnx-ml,hb-pytorch,hb-torchscript,hb-onnx,hb-tvm"
+    if "hb-tvm" in args.backend:
+        assert (
+            tvm_installed
+        ), "To run benchmark with TVM you need to have TVM installed. Either install TVM or remove it from the backends."
+    if "hb-onnx" in args.backend:
+        assert (
+            onnx_runtime_installed
+        ), "To run benchmark with ONNX you need to have ONNX runtime installed. Either install ONNX runtime or remove ONNX from the backends."
+    if "onnx-ml" in args.backend:
+        assert (
+            onnx_runtime_installed and onnx_ml_tools_installed
+        ), "To run benchmark with ONNX-ML you need to have ONNX runtime and ONNXMLTOOLS installed. Either install ONNX runtime and ONNXMLTOOLS or remove ONNX-ML from the backends."
    for backend in args.backend.split(","):
        print("Running '%s' ..." % backend)
        scorer = score.ScoreBackend.create(backend)
@@ -302,6 +322,5 @@ if __name__ == "__main__":
    assert xgboost_installed, "benchmark requires XGBoost"
    assert lightgbm_installed, "benchmark requires LightGBM"
    assert sklearn_installed, "benchmark requires sklearn"
-    assert onnx_ml_tools_installed and onnx_runtime_installed, "benchmark requires ORT and ONNXMLTOOLS"

    main()
@@ -41,6 +41,8 @@ class ScoreBackend(ABC):
            return HBBackend("torch")
        if name == "hb-torchscript":
            return HBBackend("torch.jit")
+        if name == "hb-tvm":
+            return HBBackend("tvm")
        if name == "hb-onnx":
            return HBBackend("onnx")
        if name == "onnx-ml":
@@ -94,7 +96,7 @@ class ScoreBackend(ABC):

class HBBackend(ScoreBackend):
    def __init__(self, backend):
-        super().__init__()
+        super(HBBackend, self).__init__()
        self.backend = backend

    def convert(self, model, data, args, model_name):
@@ -11,13 +11,14 @@ In Hummingbird we use two types of containers:
- containers for output models (e.g., `SklearnContainer`) used to surface output models as unified API format.
"""

-from abc import ABC
+from abc import ABC, abstractmethod
+import os
import numpy as np
from onnxconverter_common.container import CommonSklearnModelContainer
import torch

from hummingbird.ml.operator_converters import constants
-from hummingbird.ml._utils import onnx_runtime_installed, pandas_installed, _get_device
+from hummingbird.ml._utils import onnx_runtime_installed, tvm_installed, pandas_installed, _get_device

if pandas_installed():
    from pandas import DataFrame
@@ -44,7 +45,8 @@ class CommonSparkMLModelContainer(CommonSklearnModelContainer):
        super(CommonSparkMLModelContainer, self).__init__(sparkml_model)


-# Output containers
+# Output containers.
+# Abstract containers enabling the Sklearn API.
class SklearnContainer(ABC):
    def __init__(self, model, n_threads=None, batch_size=None, extra_config={}):
        """
@@ -62,6 +64,7 @@ class SklearnContainer(ABC):
        self._n_threads = n_threads
        self._batch_size = batch_size
        self._extra_config = extra_config
+        self._last_iteration = False

    @property
    def model(self):
@@ -106,6 +109,9 @@ class SklearnContainer(ABC):
                batch = tuple([input[start:end, :] for input in inputs])
            else:
                batch = inputs[start:end, :]
+            # Tell the function that we are in the last iteration so it can act accordingly
+            # (e.g., for TVM we may want to use the remainder model).
+            self._last_iteration = i == iterations - 1
            predictions.extend(function(*batch).ravel())

        if reshape:
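The batching contract above is easiest to see in isolation; a standalone sketch with illustrative names (not the Hummingbird API itself):

    import numpy as np

    def run_batched(predict_full, predict_remainder, X, batch_size):
        # predict_full handles fixed-size batches; predict_remainder (if any)
        # handles the final partial batch, mirroring _last_iteration above.
        out = []
        for start in range(0, X.shape[0], batch_size):
            end = min(start + batch_size, X.shape[0])
            fn = predict_full
            if end - start < batch_size and predict_remainder is not None:
                fn = predict_remainder
            out.extend(fn(X[start:end]).ravel())
        return np.array(out)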
@@ -113,20 +119,17 @@ class SklearnContainer(ABC):
        return np.array(predictions).ravel()


-class PyTorchTorchscriptSklearnContainer(SklearnContainer):
+class SklearnContainerTransformer(SklearnContainer):
    """
-    Base container for PyTorch and TorchScript models.
+    Abstract container mirroring Sklearn transformers API.
    """

-
-# PyTorch containers.
-class PyTorchSklearnContainerTransformer(PyTorchTorchscriptSklearnContainer):
+    @abstractmethod
+    def _transform(self, *input):
        """
-        Container mirroring Sklearn transformers API.
+        This method contains container-specific implementation of transform.
        """
-
-    def _transform(self, *inputs):
-        return self.model.forward(*inputs).cpu().numpy()
+        pass

    def transform(self, *inputs):
        """
@@ -136,28 +139,27 @@ class PyTorchSklearnContainerTransformer(PyTorchTorchscriptSklearnContainer):
        return self._run(self._transform, *inputs, reshape=True)


-class PyTorchSklearnContainerRegression(PyTorchTorchscriptSklearnContainer):
+class SklearnContainerRegression(SklearnContainer):
    """
-    Container mirroring Sklearn regressors API.
+    Abstract container mirroring Sklearn regressors API.
    """

    def __init__(
        self, model, n_threads, batch_size, is_regression=True, is_anomaly_detection=False, extra_config={}, **kwargs
    ):
-        super(PyTorchSklearnContainerRegression, self).__init__(model, n_threads, batch_size, extra_config)
+        super(SklearnContainerRegression, self).__init__(model, n_threads, batch_size, extra_config)

        assert not (is_regression and is_anomaly_detection)

        self._is_regression = is_regression
        self._is_anomaly_detection = is_anomaly_detection

-    def _predict(self, *inputs):
-        if self._is_regression:
-            return self.model.forward(*inputs).cpu().numpy().ravel()
-        elif self._is_anomaly_detection:
-            return self.model.forward(*inputs)[0].cpu().numpy().ravel()
-        else:
-            return self.model.forward(*inputs)[0].cpu().numpy().ravel()
+    @abstractmethod
+    def _predict(self, *input):
+        """
+        This method contains container-specific implementation of predict.
+        """
+        pass

    def predict(self, *inputs):
        """
@@ -169,18 +171,22 @@ class PyTorchSklearnContainerRegression(PyTorchTorchscriptSklearnContainer):
        return self._run(self._predict, *inputs)


-class PyTorchSklearnContainerClassification(PyTorchSklearnContainerRegression):
+class SklearnContainerClassification(SklearnContainerRegression):
    """
    Container mirroring Sklearn classifiers API.
    """

    def __init__(self, model, n_threads, batch_size, extra_config={}):
-        super(PyTorchSklearnContainerClassification, self).__init__(
+        super(SklearnContainerClassification, self).__init__(
            model, n_threads, batch_size, is_regression=False, extra_config=extra_config
        )

+    @abstractmethod
    def _predict_proba(self, *input):
-        return self.model.forward(*input)[1].cpu().numpy()
+        """
+        This method contains container-specific implementation of predict_proba.
+        """
+        pass

    def predict_proba(self, *inputs):
        """
@@ -190,18 +196,22 @@ class PyTorchSklearnContainerClassification(PyTorchSklearnContainerRegression):
        return self._run(self._predict_proba, *inputs, reshape=True)


-class PyTorchSklearnContainerAnomalyDetection(PyTorchSklearnContainerRegression):
+class SklearnContainerAnomalyDetection(SklearnContainerRegression):
    """
    Container mirroring Sklearn anomaly detection API.
    """

    def __init__(self, model, n_threads, batch_size, extra_config={}):
-        super(PyTorchSklearnContainerAnomalyDetection, self).__init__(
+        super(SklearnContainerAnomalyDetection, self).__init__(
            model, n_threads, batch_size, is_regression=False, is_anomaly_detection=True, extra_config=extra_config
        )

+    @abstractmethod
    def _decision_function(self, *inputs):
-        return self.model.forward(*inputs)[1].cpu().numpy().ravel()
+        """
+        This method contains container-specific implementation of decision_function.
+        """
+        pass

    def decision_function(self, *inputs):
        """
@@ -223,6 +233,48 @@ class PyTorchSklearnContainerAnomalyDetection(PyTorchSklearnContainerRegression)
        return self.decision_function(*inputs) + self._extra_config[constants.OFFSET]


+# PyTorch containers.
+class PyTorchSklearnContainerTransformer(SklearnContainerTransformer):
+    """
+    Container for PyTorch models mirroring Sklearn transformers API.
+    """
+
+    def _transform(self, *inputs):
+        return self.model.forward(*inputs).cpu().numpy()
+
+
+class PyTorchSklearnContainerRegression(SklearnContainerRegression):
+    """
+    Container for PyTorch models mirroring Sklearn regressors API.
+    """
+
+    def _predict(self, *inputs):
+        if self._is_regression:
+            return self.model.forward(*inputs).cpu().numpy().ravel()
+        elif self._is_anomaly_detection:
+            return self.model.forward(*inputs)[0].cpu().numpy().ravel()
+        else:
+            return self.model.forward(*inputs)[0].cpu().numpy().ravel()
+
+
+class PyTorchSklearnContainerClassification(PyTorchSklearnContainerRegression, SklearnContainerClassification):
+    """
+    Container for PyTorch models mirroring Sklearn classifiers API.
+    """
+
+    def _predict_proba(self, *input):
+        return self.model.forward(*input)[1].cpu().numpy()
+
+
+class PyTorchSklearnContainerAnomalyDetection(PyTorchSklearnContainerRegression, SklearnContainerAnomalyDetection):
+    """
+    Container for PyTorch models mirroring the Sklearn anomaly detection API.
+    """
+
+    def _decision_function(self, *inputs):
+        return self.model.forward(*inputs)[1].cpu().numpy().ravel()
+
+
# TorchScript containers.
def _torchscript_wrapper(device, function, *inputs):
    """
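The refactoring above splits containers into abstract API mixins plus runtime-specific bases; a reduced sketch of the pattern with simplified, illustrative names:

    from abc import ABC, abstractmethod

    class Container(ABC):                 # role of SklearnContainer
        pass

    class RegressionAPI(Container):       # role of SklearnContainerRegression
        @abstractmethod
        def _predict(self, x):
            pass

        def predict(self, x):             # shared Sklearn-style entry point
            return self._predict(x)

    class TVMBase(Container):             # runtime-specific state and helpers
        pass

    class TVMRegression(TVMBase, RegressionAPI):
        def _predict(self, x):            # only the hook is runtime-specific
            return x * 2.0                # stands in for model.run()/get_output()

    print(TVMRegression().predict(3.0))   # -> 6.0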
@@ -232,6 +284,14 @@ def _torchscript_wrapper(device, function, *inputs):
    inputs = [*inputs]

    with torch.no_grad():
+        if type(inputs) == DataFrame and DataFrame is not None:
+            # Split the dataframe into column ndarrays
+            inputs = inputs[0]
+            input_names = list(inputs.columns)
+            splits = [inputs[input_names[idx]] for idx in range(len(input_names))]
+            splits = [df.to_numpy().reshape(-1, 1) for df in splits]
+            inputs = tuple(splits)
+
        # Maps data inputs to the expected type and device.
        for i in range(len(inputs)):
            if type(inputs[i]) is np.ndarray:
@@ -245,7 +305,7 @@ def _torchscript_wrapper(device, function, *inputs):

class TorchScriptSklearnContainerTransformer(PyTorchSklearnContainerTransformer):
    """
-    Container mirroring Sklearn transformers API.
+    Container for TorchScript models mirroring Sklearn transformers API.
    """

    def transform(self, *inputs):
@@ -258,7 +318,7 @@ class TorchScriptSklearnContainerTransformer(PyTorchSklearnContainerTransformer)

class TorchScriptSklearnContainerRegression(PyTorchSklearnContainerRegression):
    """
-    Container mirroring Sklearn regressors API.
+    Container for TorchScript models mirroring Sklearn regressors API.
    """

    def predict(self, *inputs):
@@ -271,7 +331,7 @@ class TorchScriptSklearnContainerRegression(PyTorchSklearnContainerRegression):

class TorchScriptSklearnContainerClassification(PyTorchSklearnContainerClassification):
    """
-    Container mirroring Sklearn classifiers API.
+    Container for TorchScript models mirroring Sklearn classifiers API.
    """

    def predict(self, *inputs):
@@ -291,7 +351,7 @@ class TorchScriptSklearnContainerClassification(PyTorchSklearnContainerClassific

class TorchScriptSklearnContainerAnomalyDetection(PyTorchSklearnContainerAnomalyDetection):
    """
-    Container mirroring Sklearn anomaly detection API.
+    Container for TorchScript models mirroring Sklearn anomaly detection API.
    """

    def predict(self, *inputs):
@@ -306,7 +366,11 @@ class TorchScriptSklearnContainerAnomalyDetection(PyTorchSklearnContainerAnomaly
        f = super(TorchScriptSklearnContainerAnomalyDetection, self)._decision_function
        f_wrapped = lambda x: _torchscript_wrapper(device, f, x)  # noqa: E731

-        return self._run(f_wrapped, *inputs)
+        scores = self._run(f_wrapped, *inputs)
+
+        if constants.IFOREST_THRESHOLD in self._extra_config:
+            scores += self._extra_config[constants.IFOREST_THRESHOLD]
+        return scores

    def score_samples(self, *inputs):
        device = _get_device(self.model)
@@ -329,9 +393,6 @@ class ONNXSklearnContainer(SklearnContainer):
        if onnx_runtime_installed():
            import onnxruntime as ort

-            self._model = model
-            self._extra_config = extra_config
-
            sess_options = ort.SessionOptions()
            if self._n_threads is not None:
                sess_options.intra_op_num_threads = self._n_threads
@@ -339,7 +400,7 @@ class ONNXSklearnContainer(SklearnContainer):
            sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
            self._session = ort.InferenceSession(self._model.SerializeToString(), sess_options=sess_options)
            self._output_names = [self._session.get_outputs()[i].name for i in range(len(self._session.get_outputs()))]
-            self.input_names = [input.name for input in self._session.get_inputs()]
+            self._input_names = [input.name for input in self._session.get_inputs()]
        else:
            raise RuntimeError("ONNX Container requires ONNX runtime installed.")
@@ -347,147 +408,154 @@
        """
        Retrieve the inputs names from the session object.
        """
-        if len(inputs) < len(self.input_names):
+        if len(inputs) < len(self._input_names):
            inputs = inputs[0]

-        assert len(inputs) == len(self.input_names)
+        assert len(inputs) == len(self._input_names)

        named_inputs = {}

        for i in range(len(inputs)):
-            named_inputs[self.input_names[i]] = np.array(inputs[i])
+            named_inputs[self._input_names[i]] = np.array(inputs[i])

        return named_inputs


-class ONNXSklearnContainerTransformer(ONNXSklearnContainer):
+class ONNXSklearnContainerTransformer(ONNXSklearnContainer, SklearnContainerTransformer):
    """
-    Container mirroring Sklearn transformers API.
+    Container for ONNX models mirroring Sklearn transformers API.
    """

-    def __init__(self, model, n_threads=None, batch_size=None, extra_config={}):
-        super(ONNXSklearnContainerTransformer, self).__init__(model, n_threads, batch_size, extra_config)
-
-        assert len(self._output_names) == 1
-
    def _transform(self, *inputs):
+        assert len(self._output_names) == 1
+
        named_inputs = self._get_named_inputs(inputs)

        return np.array(self._session.run(self._output_names, named_inputs))

-    def transform(self, *inputs):
-        """
-        Utility functions used to emulate the behavior of the Sklearn API.
-        On data transformers it returns transformed output data
-        """
-        return self._run(self._transform, *inputs, reshape=True)
-

-class ONNXSklearnContainerRegression(ONNXSklearnContainer):
+class ONNXSklearnContainerRegression(ONNXSklearnContainer, SklearnContainerRegression):
    """
-    Container mirroring Sklearn regressors API.
+    Container for ONNX models mirroring Sklearn regressors API.
    """

    def __init__(
        self, model, n_threads=None, batch_size=None, is_regression=True, is_anomaly_detection=False, extra_config={}, **kwargs
    ):
        super(ONNXSklearnContainerRegression, self).__init__(model, n_threads, batch_size, extra_config)

        assert not (is_regression and is_anomaly_detection)
-        if is_regression:
-            assert len(self._output_names) == 1

        self._is_regression = is_regression
        self._is_anomaly_detection = is_anomaly_detection

    def _predict(self, *inputs):
-        """
-        Utility functions used to emulate the behavior of the Sklearn API.
-        On regression returns the predicted values.
-        On classification tasks returns the predicted class labels for the input data.
-        On anomaly detection (e.g. isolation forest) returns the predicted classes (-1 or 1).
-        """
        named_inputs = self._get_named_inputs(inputs)

        if self._is_regression:
+            assert len(self._output_names) == 1
+
            return np.array(self._session.run(self._output_names, named_inputs))
        elif self._is_anomaly_detection:
            return np.array(self._session.run([self._output_names[0]], named_inputs))[0].ravel()
        else:
+            assert len(self._output_names) == 2
+
            return np.array(self._session.run([self._output_names[0]], named_inputs))[0]

-    def predict(self, *inputs):
-        """
-        Utility functions used to emulate the behavior of the Sklearn API.
-        On data transformers it returns transformed output data
-        """
-        return self._run(self._predict, *inputs)
-

-class ONNXSklearnContainerClassification(ONNXSklearnContainerRegression):
+class ONNXSklearnContainerClassification(ONNXSklearnContainerRegression, SklearnContainerClassification):
    """
-    Container mirroring Sklearn classifiers API.
+    Container for ONNX models mirroring Sklearn classifiers API.
    """

-    def __init__(self, model, n_threads=None, batch_size=None, extra_config={}):
-        super(ONNXSklearnContainerClassification, self).__init__(
-            model, n_threads, batch_size, is_regression=False, extra_config=extra_config
-        )
-
-        assert len(self._output_names) == 2
-
    def _predict_proba(self, *inputs):
-        """
-        Utility functions used to emulate the behavior of the Sklearn API.
-        On classification tasks returns the probability estimates.
-        """
+        assert len(self._output_names) == 2
+
        named_inputs = self._get_named_inputs(inputs)

        return self._session.run([self._output_names[1]], named_inputs)[0]

-    def predict_proba(self, *inputs):
-        """
-        Utility functions used to emulate the behavior of the Sklearn API.
-        On data transformers it returns transformed output data
-        """
-        return self._run(self._predict_proba, *inputs, reshape=True)
-

-class ONNXSklearnContainerAnomalyDetection(ONNXSklearnContainerRegression):
+class ONNXSklearnContainerAnomalyDetection(ONNXSklearnContainerRegression, SklearnContainerAnomalyDetection):
    """
-    Container mirroring Sklearn anomaly detection API.
+    Container for ONNX models mirroring Sklearn anomaly detection API.
    """

-    def __init__(self, model, n_threads=None, batch_size=None, extra_config={}):
-        super(ONNXSklearnContainerAnomalyDetection, self).__init__(
-            model, n_threads, batch_size, is_regression=False, is_anomaly_detection=True, extra_config=extra_config
-        )
-
-        assert len(self._output_names) == 2
-
    def _decision_function(self, *inputs):
-        """
-        Utility functions used to emulate the behavior of the Sklearn API.
-        On anomaly detection (e.g. isolation forest) returns the decision function scores.
-        """
+        assert len(self._output_names) == 2
+
        named_inputs = self._get_named_inputs(inputs)

-        scores = np.array(self._session.run([self._output_names[1]], named_inputs)[0]).flatten()
-        # Backward compatibility for sklearn <= 0.21
-        if constants.IFOREST_THRESHOLD in self._extra_config:
-            scores += self._extra_config[constants.IFOREST_THRESHOLD]
-        return scores
-
-    def decision_function(self, *inputs):
-        """
-        Utility functions used to emulate the behavior of the Sklearn API.
-        On data transformers it returns transformed output data
-        """
-        return self._run(self._decision_function, *inputs)
-
-    def score_samples(self, *inputs):
-        """
-        Utility functions used to emulate the behavior of the Sklearn API.
-        On anomaly detection (e.g. isolation forest) returns the decision_function score plus offset_
-        """
-        return self.decision_function(*inputs) + self._extra_config[constants.OFFSET]
+        return np.array(self._session.run([self._output_names[1]], named_inputs)[0]).flatten()


+# TVM containers.
+class TVMSklearnContainer(SklearnContainer):
+    """
+    Base container for TVM models.
+    The container allows to mirror the Sklearn API.
+    """
+
+    def __init__(self, model, n_threads=None, batch_size=None, extra_config={}):
+        super(TVMSklearnContainer, self).__init__(model, n_threads, batch_size, extra_config=extra_config)
+
+        assert tvm_installed()
+        import tvm
+
+        self._ctx = self._extra_config[constants.TVM_CONTEXT]
+        self._input_names = self._extra_config[constants.TVM_INPUT_NAMES]
+        self._remainder_model = None
+        if constants.TVM_REMAINDER_MODEL in self._extra_config:
+            self._remainder_model = self._extra_config[constants.TVM_REMAINDER_MODEL]
+        self._to_tvm_array = lambda x: tvm.nd.array(x, self._ctx)
+
+        os.environ["TVM_NUM_THREADS"] = str(self._n_threads)
+
+    def _to_tvm_tensor(self, *inputs):
+        return {self._input_names[i]: self._to_tvm_array(inputs[i]) for i in range(len(inputs))}
+
+
+class TVMSklearnContainerTransformer(TVMSklearnContainer, SklearnContainerTransformer):
+    """
+    Container for TVM models mirroring Sklearn transformers API.
+    """
+
+    def _transform(self, *inputs):
+        if self._last_iteration and self._remainder_model is not None:
+            self._remainder_model.run(**self._to_tvm_tensor(*inputs))
+            return self._remainder_model.get_output(0).asnumpy()
+        self.model.run(**self._to_tvm_tensor(*inputs))
+        return self.model.get_output(0).asnumpy()
+
+
+class TVMSklearnContainerRegression(TVMSklearnContainer, SklearnContainerRegression):
+    """
+    Container for TVM models mirroring Sklearn regressors API.
+    """
+
+    def _predict(self, *inputs):
+        if self._last_iteration and self._remainder_model is not None:
+            self._remainder_model.run(**self._to_tvm_tensor(*inputs))
+            return self._remainder_model.get_output(0).asnumpy().ravel()
+        self.model.run(**self._to_tvm_tensor(*inputs))
+        return self.model.get_output(0).asnumpy().ravel()
+
+
+class TVMSklearnContainerClassification(TVMSklearnContainerRegression, SklearnContainerClassification):
+    """
+    Container for TVM models mirroring Sklearn classifiers API.
+    """
+
+    def _predict_proba(self, *inputs):
+        if self._last_iteration and self._remainder_model is not None:
+            self._remainder_model.run(**self._to_tvm_tensor(*inputs))
+            return self._remainder_model.get_output(1).asnumpy()
+        self.model.run(**self._to_tvm_tensor(*inputs))
+        return self.model.get_output(1).asnumpy()
+
+
+class TVMSklearnContainerAnomalyDetection(TVMSklearnContainerRegression, SklearnContainerAnomalyDetection):
+    """
+    Container for TVM models mirroring Sklearn anomaly detection API.
+    """
+
+    def _decision_function(self, *inputs):
+        if self._last_iteration and self._remainder_model is not None:
+            self._remainder_model.run(**self._to_tvm_tensor(*inputs))
+            return self._remainder_model.get_output(1).asnumpy().ravel()
+        else:
+            self.model.run(**self._to_tvm_tensor(*inputs))
+            return self.model.get_output(1).asnumpy().ravel()
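For reference, the graph-runtime calling convention the TVM containers rely on, reduced to one hedged helper (parameter names are illustrative; the run/get_output pattern is the one used in the containers above):

    import numpy as np

    def tvm_predict(tvm_model, input_names, to_tvm_array, *inputs):
        # Mirrors TVMSklearnContainer: feed named tensors, run, read output 0.
        tvm_model.run(**{input_names[i]: to_tvm_array(inputs[i]) for i in range(len(inputs))})
        return tvm_model.get_output(0).asnumpy()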
@@ -15,6 +15,7 @@ from uuid import uuid4

from onnxconverter_common.registration import get_converter
import onnx
+import timeit

from hummingbird.ml._container import (
    PyTorchSklearnContainerRegression,
@@ -29,8 +30,12 @@ from hummingbird.ml._container import (
    ONNXSklearnContainerClassification,
    ONNXSklearnContainerTransformer,
    ONNXSklearnContainerAnomalyDetection,
+    TVMSklearnContainerRegression,
+    TVMSklearnContainerClassification,
+    TVMSklearnContainerTransformer,
+    TVMSklearnContainerAnomalyDetection,
)
-from hummingbird.ml._utils import pandas_installed, _get_device
+from hummingbird.ml._utils import pandas_installed, tvm_installed, _get_device
from hummingbird.ml.exceptions import MissingConverter
from hummingbird.ml.operator_converters import constants
@@ -40,20 +45,40 @@ else:
    DataFrame = None


+def _jit_model(torch_model, trace_input, device, extra_config):
+    """
+    Function used to convert an input pytorch model into torchscript.
+    """
+    if device != "cpu":
+        trace_input.to(device)
+    return torch.jit.trace(torch_model, trace_input).eval()
+
+
def _get_trace_input_from_test_input(input, batch_size):
    """
    Utility function used to properly put the inputs into a format understandable by torch.
+    If a non-None batch_size is passed, the function generates a tracing input of batch_size.
+    If input size % batch_size is not 0, a tracing input of the remainder is also generated.
    """
+    remainder = None
    if type(input) is tuple:
        if batch_size is not None:
            trace_input = tuple([torch.from_numpy(i)[0:batch_size, :] for i in input])
+            if len(input) > 0 and input[0].shape[0] % batch_size != 0:
+                remainder_size = input[0].shape[0] % batch_size
+                remainder = tuple([torch.from_numpy(i)[0:remainder_size, :] for i in input])
        else:
            trace_input = tuple([torch.from_numpy(i) for i in input])
    else:
        trace_input = torch.from_numpy(input)
        if batch_size is not None:
-            trace_input = trace_input[0:batch_size, :]
-    return trace_input
+            batch_input = trace_input[0:batch_size, :]
+            remainder_size = len(input) % batch_size
+            if remainder_size != 0:
+                remainder = trace_input[0:remainder_size, :]
+            trace_input = batch_input
+
+    return (trace_input, remainder)


def convert(topology, backend, device, extra_config={}):
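A worked instance of the batch/remainder split performed above:

    # With 1050 test rows and batch_size=100, tracing yields a (100, n) batch
    # input plus a (50, n) remainder input, so one TVM model can be compiled
    # per fixed shape.
    rows, batch_size = 1050, 100
    remainder_size = rows % batch_size  # 50; no remainder model when this is 0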
@@ -73,8 +98,16 @@ def convert(topology, backend, device, extra_config={}):
    assert backend is not None, "Cannot convert a Topology object into backend None."
    assert device is not None, "Cannot convert a Topology object into device None."

+    tvm_backend = None
    operator_map = {}

+    if tvm_installed():
+        import tvm
+        from tvm import relay
+        from tvm.contrib import graph_runtime
+
+        tvm_backend = tvm.__name__
+
    for operator in topology.topological_operator_iterator():
        try:
            converter = get_converter(operator.type)
@@ -128,12 +161,12 @@ def convert(topology, backend, device, extra_config={}):
        output_model_name = str(uuid4().hex) + ".onnx"

        # Put the tracing test input into the right format.
-        trace_input = _get_trace_input_from_test_input(extra_config[constants.TEST_INPUT], batch_size)
+        batch_trace_input, _ = _get_trace_input_from_test_input(extra_config[constants.TEST_INPUT], batch_size)

        # Generate the ONNX models
        torch.onnx.export(
            torch_model,
-            trace_input,
+            batch_trace_input,
            output_model_name,
            input_names=topology.raw_model.input_names,
            output_names=topology.raw_model.output_names,
@@ -185,6 +218,77 @@ def convert(topology, backend, device, extra_config={}):
            return num_fixed

        fix_graph(hb_model.graph)
+    elif backend == tvm_backend:
+        # First we need to generate the torchscript model.
+        batch_trace_input, remainder_trace_input = _get_trace_input_from_test_input(
+            extra_config[constants.TEST_INPUT], batch_size
+        )
+        ts_model = _jit_model(torch_model, batch_trace_input, "cpu", extra_config)
+        if remainder_trace_input is not None:
+            remainder_ts_model = _jit_model(torch_model, remainder_trace_input, "cpu", extra_config)
+
+        # Generate the test input in the TVM format. In case we have a remainder beyond the batch, generate a remainder test input as well.
+        test_input = [
+            (
+                topology.raw_model.input_names[i],
+                batch_trace_input[i].shape if type(batch_trace_input) is tuple else batch_trace_input.shape,
+            )
+            for i in range(len(topology.raw_model.input_names))
+        ]
+        if remainder_trace_input is not None:
+            remainder_test_input = [
+                (
+                    topology.raw_model.input_names[i],
+                    remainder_trace_input[i].shape if type(remainder_trace_input) is tuple else remainder_trace_input.shape,
+                )
+                for i in range(len(topology.raw_model.input_names))
+            ]
+
+        # Pick the proper target.
+        if device == "cuda":
+            target = tvm.target.cuda()
+            ctx = tvm.gpu()
+        elif device == "cpu":
+            target = "llvm"
+            ctx = tvm.cpu()
+        elif "llvm" in device:
+            target = device
+            ctx = tvm.cpu()
+        else:
+            raise RuntimeError("Device {} not recognized".format(device))
+
+        # Get configuration parameters.
+        config = {}
+        if constants.TVM_MAX_FUSE_DEPTH in extra_config:
+            config["relay.FuseOps.max_depth"] = extra_config[constants.TVM_MAX_FUSE_DEPTH]
+        else:
+            # 50 is a good depth for operator fusion. More than that will probably hurt performance.
+            # https://github.com/microsoft/hummingbird/issues/232#issuecomment-697979508
+            config["relay.FuseOps.max_depth"] = 50
+
+        # Create the relay version of the model.
+        model, params = relay.frontend.from_pytorch(ts_model, test_input)
+        if remainder_trace_input is not None:
+            remainder_model, remainder_params = relay.frontend.from_pytorch(remainder_ts_model, remainder_test_input)
+
+        # Generate the model. We set opt_level=3 to enable all optimizations.
+        with tvm.transform.PassContext(opt_level=3, config=config):
+            graph, lib, params = relay.build(model, target=target, params=params)
+        tvm_model = graph_runtime.create(graph, lib, ctx)
+        tvm_model.set_input(**params)
+        if remainder_trace_input is not None:
+            with tvm.transform.PassContext(opt_level=3, config=config):
+                graph, lib, params = relay.build(remainder_model, target=target, params=remainder_params)
+            tvm_remainder_model = graph_runtime.create(graph, lib, ctx)
+            tvm_remainder_model.set_input(**params)
+
+        # In the container we will be using the context to properly configure the input tensors.
+        extra_config[constants.TVM_CONTEXT] = ctx
+        extra_config[constants.TVM_INPUT_NAMES] = topology.raw_model.input_names
+        if remainder_trace_input is not None:
+            extra_config[constants.TVM_REMAINDER_MODEL] = tvm_remainder_model
+
+        hb_model = tvm_model
    else:
        # Set the device for the model.
        if device != "cpu":
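How the two knobs read above are meant to be set from user code, as a hedged sketch reusing skl_model and X from the earlier example (the literal config keys are assumptions here; the canonical constants are exposed via hummingbird.ml.supported):

    from hummingbird.ml import convert

    # Cap relay operator fusion depth and fix the scoring batch size
    # (string key names assumed, not confirmed by this diff).
    hb_model = convert(
        skl_model, "tvm", X, device="cpu",
        extra_config={"tvm_max_fuse_depth": 30, "batch_size": 100},
    )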
@@ -193,18 +297,19 @@ def convert(topology, backend, device, extra_config={}):

    # If the backend is torchscript, jit the model.
    if backend == torch.jit.__name__:
-        trace_input = _get_trace_input_from_test_input(extra_config[constants.TEST_INPUT], batch_size)
+        trace_input, _ = _get_trace_input_from_test_input(extra_config[constants.TEST_INPUT], batch_size)
        if device != "cpu":
            trace_input.to(device)
        torch_model = torch.jit.trace(torch_model, trace_input).eval()
        torch.jit.optimized_execution(torch_model)

        hb_model = torch_model

    # Return if the container is not needed.
    if constants.CONTAINER in extra_config and not extra_config[constants.CONTAINER]:
        return hb_model

-    # We scan the operators backwards until we find an operator with a defined type
+    # We scan the operators backwards until we find an operator with a defined type.
    # This is necessary because ONNX models can have arbitrary operators doing casting, reshaping etc.
    idx = len(operators) - 1
    while (
@@ -232,12 +337,15 @@ def convert(topology, backend, device, extra_config={}):
    if idx < 0:
        idx = tmp_idx

    # Get the proper container type.
    if operator_map[operators[idx].full_name].regression:
        # We are doing a regression task.
        if backend == torch.jit.__name__:
            container = TorchScriptSklearnContainerRegression
        elif backend == onnx.__name__:
            container = ONNXSklearnContainerRegression
+        elif backend == tvm_backend:
+            container = TVMSklearnContainerRegression
        else:
            container = PyTorchSklearnContainerRegression
    elif operator_map[operators[idx].full_name].anomaly_detection:
@@ -246,6 +354,8 @@ def convert(topology, backend, device, extra_config={}):
            container = TorchScriptSklearnContainerAnomalyDetection
        elif backend == onnx.__name__:
            container = ONNXSklearnContainerAnomalyDetection
+        elif backend == tvm_backend:
+            container = TVMSklearnContainerAnomalyDetection
        else:
            container = PyTorchSklearnContainerAnomalyDetection
    elif operator_map[operators[idx].full_name].transformer:
@@ -254,6 +364,8 @@ def convert(topology, backend, device, extra_config={}):
            container = TorchScriptSklearnContainerTransformer
        elif backend == onnx.__name__:
            container = ONNXSklearnContainerTransformer
+        elif backend == tvm_backend:
+            container = TVMSklearnContainerTransformer
        else:
            container = PyTorchSklearnContainerTransformer
    else:
@@ -262,6 +374,8 @@ def convert(topology, backend, device, extra_config={}):
            container = TorchScriptSklearnContainerClassification
        elif backend == onnx.__name__:
            container = ONNXSklearnContainerClassification
+        elif backend == tvm_backend:
+            container = TVMSklearnContainerClassification
        else:
            container = PyTorchSklearnContainerClassification
@@ -113,6 +113,18 @@ def xgboost_installed():
    return True


+def tvm_installed():
+    """
+    Checks that *TVM* is available.
+    """
+    try:
+        import tvm
+        from tvm import relay
+    except ImportError:
+        return False
+    return True
+
+
def pandas_installed():
    """
    Checks that *Pandas* is available.
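The new tvm_installed helper above is used as an import guard throughout the commit; the idiom, as it appears in _topology.py and convert.py:

    from hummingbird.ml._utils import tvm_installed

    tvm_backend = None
    if tvm_installed():
        import tvm

        tvm_backend = tvm.__name__  # "tvm"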
@@ -22,6 +22,7 @@ from ._utils import (
    sparkml_installed,
    is_pandas_dataframe,
    is_spark_dataframe,
+    tvm_installed,
)
from .exceptions import MissingConverter, MissingBackend
from .supported import backends
@@ -70,7 +71,13 @@ def _supported_backend_check_config(model, backend, extra_config):
    import onnx
    import torch

-    if backend is torch.jit.__name__ and constants.TEST_INPUT not in extra_config:
+    tvm_backend = None
+    if tvm_installed():
+        import tvm
+
+        tvm_backend = tvm.__name__
+
+    if (backend == torch.jit.__name__ or backend == tvm_backend) and constants.TEST_INPUT not in extra_config:
        raise RuntimeError("Backend {} requires test inputs. Please pass some test input to the convert.".format(backend))
@@ -262,7 +269,7 @@ def convert(model, backend, test_input=None, device="cpu", extra_config={}):
    For *LightGBM* and *XGBoost* currently only the Sklearn API is supported.
    The detailed list of models and backends can be found at `hummingbird.ml.supported`.
    The *onnx* backend requires either a test_input or the initial types set through the extra_config parameter.
-    The *torch.jit* backend requires a test_input.
+    The *torch.jit* and *tvm* backends require a test_input.
    [Sklearn]: https://scikit-learn.org/
    [LightGBM]: https://lightgbm.readthedocs.io/
    [XGBoost]: https://xgboost.readthedocs.io/
@@ -276,8 +283,8 @@ def convert(model, backend, test_input=None, device="cpu", extra_config={}):
    backend: The target for the conversion
    test_input: Some input data used to trace the model execution.
        Multiple inputs can be passed as `tuple` objects or pandas Dataframes.
-        When possible, (`numpy`)`arrays` are suggesed.
-    device: The target device the model should be run. This parameter is only used by the *torch** backends, and
+        When possible, (`numpy`)`arrays` are suggested.
+    device: The target device the model should be run. This parameter is only used by the *torch** backends and *tvm*, and
        the devices supported are the one supported by PyTorch, i.e., 'cpu' or 'cuda'.
    extra_config: Extra configurations to be used by the individual operator converters.
        The set of supported extra configurations can be found at `hummingbird.ml.supported`
@@ -17,7 +17,7 @@ class LinearModel(BaseOperator, torch.nn.Module):
    def __init__(self, coefficients, intercepts, device, classes=[0], multi_class=None, is_linear_regression=False):
        super(LinearModel, self).__init__()
        self.coefficients = torch.nn.Parameter(torch.from_numpy(coefficients), requires_grad=False)
-        self.intercepts = torch.nn.Parameter(torch.from_numpy(intercepts), requires_grad=False)
+        self.intercepts = torch.nn.Parameter(torch.from_numpy(intercepts).view(-1), requires_grad=False)
        self.classes = torch.nn.Parameter(torch.IntTensor(classes), requires_grad=False)
        self.multi_class = multi_class
        self.regression = is_linear_regression
@@ -19,7 +19,9 @@ class BernoulliNBModel(BaseOperator, torch.nn.Module):
        super(BernoulliNBModel, self).__init__()
        self.classification = True
        self.binarize = binarize
-        self.jll_calc_bias = torch.nn.Parameter(torch.from_numpy(jll_calc_bias.astype("float32")), requires_grad=False)
+        self.jll_calc_bias = torch.nn.Parameter(
+            torch.from_numpy(jll_calc_bias.astype("float32")).view(-1), requires_grad=False
+        )
        self.feature_log_prob_minus_neg_prob = torch.nn.Parameter(
            torch.from_numpy(feature_log_prob_minus_neg_prob.astype("float32")), requires_grad=False
        )
@@ -25,10 +25,10 @@ class Scaler(BaseOperator, torch.nn.Module):
        self.scale = scale

        if offset is not None:
-            self.offset = torch.nn.Parameter(torch.FloatTensor([offset]), requires_grad=False)
+            self.offset = torch.nn.Parameter(torch.DoubleTensor([offset]), requires_grad=False)

        if scale is not None:
-            self.scale = torch.nn.Parameter(torch.FloatTensor([scale]), requires_grad=False)
+            self.scale = torch.nn.Parameter(torch.DoubleTensor([scale]), requires_grad=False)

    def forward(self, x):
        if self.offset is not None:
@@ -408,6 +408,6 @@ def convert_decision_ensemble_tree_common(
        for tree_param in tree_parameters
    ]
    if tree_type == TreeImpl.tree_trav:
-        return TreeTraversalDecisionTreeImpl(net_parameters, max_depth, n_features, classes)
+        return TreeTraversalDecisionTreeImpl(net_parameters, max_depth, n_features, classes, extra_config)
    else:  # Remaining possible case: tree_type == TreeImpl.perf_tree_trav
        return PerfectTreeTraversalDecisionTreeImpl(net_parameters, max_depth, n_features, classes)
@@ -162,7 +162,7 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl):

        if self.anomaly_detection:
            # Select the class (-1 if negative) and return the score.
-            return torch.where(x < 0, self.classes[0], self.classes[1]), x
+            return torch.where(x.view(-1) < 0, self.classes[0], self.classes[1]), x

        if self.perform_class_select:
            return torch.index_select(self.classes, 0, torch.argmax(x, dim=1)), x
@@ -175,7 +175,12 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
    Class implementing the Tree Traversal strategy in PyTorch for tree-base models.
    """

-    def __init__(self, tree_parameters, max_depth, n_features, classes, n_classes=None, **kwargs):
+    def _expand_indexes(self, batch_size):
+        indexes = self.nodes_offset
+        indexes = indexes.expand(batch_size, self.num_trees)
+        return indexes.reshape(-1)
+
+    def __init__(self, tree_parameters, max_depth, n_features, classes, n_classes=None, extra_config={}, **kwargs):
        """
        Args:
            tree_parameters: The parameters defining the tree structure
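What _expand_indexes computes, on illustrative numbers (the offsets are the positions of each tree's root in the flattened node table):

    import torch

    # Say 3 trees are padded to 4 nodes each: roots sit at flat offsets 0, 4, 8.
    nodes_offset = torch.tensor([[0, 4, 8]])          # shape (1, num_trees)
    indexes = nodes_offset.expand(2, 3).reshape(-1)   # batch_size=2
    print(indexes)                                    # tensor([0, 4, 8, 0, 4, 8])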
@@ -183,6 +188,7 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
            n_features: The number of features input to the model
            classes: The classes used for classification. None if implementing a regression model
            n_classes: The total number of used classes
+            extra_config: Extra configuration used to properly implement the source tree
        """
        super(TreeTraversalTreeImpl, self).__init__(tree_parameters, n_features, classes, n_classes, **kwargs)
@@ -192,8 +198,8 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
        self.num_trees = len(tree_parameters)
        self.num_nodes = max([len(tree_parameter[1]) for tree_parameter in tree_parameters])

-        lefts = np.zeros((self.num_trees, self.num_nodes), dtype=np.float32)
-        rights = np.zeros((self.num_trees, self.num_nodes), dtype=np.float32)
+        lefts = np.zeros((self.num_trees, self.num_nodes), dtype=np.int64)
+        rights = np.zeros((self.num_trees, self.num_nodes), dtype=np.int64)

        features = np.zeros((self.num_trees, self.num_nodes), dtype=np.int64)
        thresholds = np.zeros((self.num_trees, self.num_nodes), dtype=np.float32)
@@ -220,9 +226,7 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
        return x

    def forward(self, x):
-        indexes = self.nodes_offset
-        indexes = indexes.expand(x.size()[0], self.num_trees)
-        indexes = indexes.reshape(-1)
+        indexes = self._expand_indexes(x.size()[0])

        for _ in range(self.max_tree_depth):
            tree_nodes = indexes
@@ -246,7 +250,7 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):

        if self.anomaly_detection:
            # Select the class (-1 if negative) and return the score.
-            return torch.where(output < 0, self.classes[0], self.classes[1]), output
+            return torch.where(output.view(-1) < 0, self.classes[0], self.classes[1]), output

        if self.perform_class_select:
            return torch.index_select(self.classes, 0, torch.argmax(output, dim=1)), output

@ -325,11 +329,9 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
for nodes, biases in zip(self.nodes, self.biases):
gather_indices = torch.index_select(nodes, 0, prev_indices).view(-1, self.num_trees)
features = torch.gather(x, 1, gather_indices).view(-1)
prev_indices = factor * prev_indices + torch.ge(features, torch.index_select(biases, 0, prev_indices)).long().view(-1)
prev_indices = factor * prev_indices + torch.ge(features, torch.index_select(biases, 0, prev_indices)).long()

output = torch.index_select(self.leaf_nodes, 0, prev_indices.view(-1)).view(-1, self.num_trees, self.n_classes)
output = torch.index_select(self.leaf_nodes, 0, prev_indices).view(-1, self.num_trees, self.n_classes)

output = self.aggregation(output)
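The loop above descends one tree level per iteration; since `prev_indices` now stays 1-D throughout, the trailing `.view(-1)` calls became redundant. The update `factor * prev_indices + ge(...)` is the standard perfect-tree index trick. A toy sketch with a single depth-2 tree and factor 2 (layout and values assumed):

import torch

# Per-level thresholds of one perfect binary tree (toy values).
level_biases = [torch.tensor([0.5]), torch.tensor([0.2, 0.8])]
feature = torch.tensor([0.7])  # the sample's value at every visited node

idx = torch.zeros(1, dtype=torch.long)
for biases in level_biases:
    go_right = torch.ge(feature, torch.index_select(biases, 0, idx)).long()
    idx = 2 * idx + go_right  # factor == 2 for binary trees
print(idx)  # tensor([2]): right at the root (0.7 >= 0.5), then left at node 0.8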

@ -338,7 +340,7 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
if self.anomaly_detection:
# Select the class (-1 if negative) and return the score.
return torch.where(output < 0, self.classes[0], self.classes[1]), output
return torch.where(output.view(-1) < 0, self.classes[0], self.classes[1]), output

if self.perform_class_select:
return torch.index_select(self.classes, 0, torch.argmax(output, dim=1)), output

@ -414,15 +416,18 @@ class TreeTraversalDecisionTreeImpl(TreeTraversalTreeImpl):
Class implementing the Tree Traversal strategy in PyTorch for decision tree models.
"""

def __init__(self, tree_parameters, max_depth, n_features, classes=None):
def __init__(self, tree_parameters, max_depth, n_features, classes=None, extra_config={}):
"""
Args:
tree_parameters: The parameters defining the tree structure
max_depth: The maximum tree-depth in the model
n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model
extra_config: Extra configuration used to properly implement the source tree
"""
super(TreeTraversalDecisionTreeImpl, self).__init__(tree_parameters, max_depth, n_features, classes)
super(TreeTraversalDecisionTreeImpl, self).__init__(
tree_parameters, max_depth, n_features, classes, extra_config=extra_config
)

def aggregation(self, x):
output = x.sum(1)

@ -498,7 +503,7 @@ class TreeTraversalGBDTImpl(TreeTraversalTreeImpl):
classes: The classes used for classification. None if implementing a regression model
extra_config: Extra configuration used to properly implement the source tree
"""
super(TreeTraversalGBDTImpl, self).__init__(tree_parameters, max_detph, n_features, classes, 1)
super(TreeTraversalGBDTImpl, self).__init__(tree_parameters, max_detph, n_features, classes, 1, extra_config)

self.n_gbdt_classes = 1
self.post_transform = lambda x: x

@ -38,6 +38,15 @@ ONNX_INITIALIZERS = "onnx_initializers"
ONNX_INPUTS = "onnx_inputs"
"""The inputs of the ONNX model."""

TVM_CONTEXT = "tvm_context"
"""The context for TVM containing information on the target."""

TVM_INPUT_NAMES = "tvm_input_names"
"""TVM expects named inputs. This is used to set the names for the inputs."""

TVM_REMAINDER_MODEL = "tvm_remainder_model"
"""TVM models are statically compiled; when batching, we may need a different model for the remainder of the records."""

TEST_INPUT = "test_input"
"""The test input data for models that need to be traced."""

@ -30,7 +30,7 @@ class SVC(BaseOperator, torch.nn.Module):
self.coef0 = coef0
self.n_features = sv.shape[1]
self.a = a
self.b = torch.nn.Parameter(torch.nn.Parameter(torch.from_numpy(b.reshape(1, -1)).double()), requires_grad=False)
self.b = torch.nn.Parameter(torch.from_numpy(b.reshape(1, -1)).double(), requires_grad=False)
self.start = [sum(nv[:i]) for i in range(len(nv))]
self.end = [self.start[i] + nv[i] for i in range(len(nv))]
self.len_nv = len(nv)

@ -52,7 +52,7 @@ class SVC(BaseOperator, torch.nn.Module):
# using quadratic expansion -- susceptible to round-off errors
# http://www.robots.ox.ac.uk/~albanie/notes/Euclidean_distance_trick.pdf
x_norm = -self.gamma * (x ** 2).sum(1).view(-1, 1)
k = torch.exp(x_norm + self.sv_norm + 2.0 * self.gamma * torch.mm(x, self.sv_t))
k = torch.exp(x_norm + self.sv_norm + 2.0 * self.gamma * torch.mm(x, self.sv_t).double())
elif self.kernel == "sigmoid":
k = torch.sigmoid(self.gamma * torch.mm(x, self.sv_t) + self.coef0)
else:  # poly kernel
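The expansion referenced in the comment is ||x - s||^2 = ||x||^2 - 2 x.s + ||s||^2, so exp(-gamma ||x - s||^2) factors into the three precomputable terms used above. A quick numerical check (illustrative only):

import numpy as np

rng = np.random.default_rng(0)
gamma = 0.1
x, s = rng.random(5), rng.random(5)

direct = np.exp(-gamma * np.sum((x - s) ** 2))
expanded = np.exp(-gamma * x @ x + 2.0 * gamma * x @ s - gamma * s @ s)
assert np.isclose(direct, expanded)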

@ -69,10 +69,10 @@ class SVC(BaseOperator, torch.nn.Module):
class_ids = torch.gt(c, 0.0).int().flatten()
else:
votes = torch.where(c > 0, self.true_classes, self.false_classes)
# TODO mode is still not implemented for GPU backend
# TODO mode is still not implemented for GPU backend.
votes = votes.data.cpu()
class_ids, _ = torch.mode(votes, dim=1)
# no class probabilities in SVC
# No class probabilities in SVC.
if self.perform_class_select:
temp = torch.index_select(self.classes, 0, class_ids.long())
return temp, temp

@ -10,7 +10,8 @@ All operators, backends, and configurations settings supported in Hummingbird ar
**Supported Backends**
PyTorch,
TorchScript,
ONNX
ONNX,
TVM

**Supported Operators (scikit-learn)**
BernoulliNB,

@ -90,6 +91,7 @@ from ._utils import (
lightgbm_installed,
xgboost_installed,
onnx_runtime_installed,
tvm_installed,
sparkml_installed,
)

@ -319,6 +321,11 @@ def _build_backend_map():
backends[onnx.__name__] = onnx.__name__

if tvm_installed():
import tvm

backends[tvm.__name__] = tvm.__name__

return backends
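tvm_installed() presumably follows the same try/import pattern as the other *_installed guards. A minimal sketch of that pattern (an assumption; the actual helper ships in hummingbird/ml/_utils.py):

def tvm_installed():
    # Probe for the optional dependency without making it a hard requirement.
    try:
        import tvm  # noqa: F401
    except ImportError:
        return False
    return True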

@ -429,6 +436,11 @@ ONNX_OUTPUT_MODEL_NAME = "onnx_model_name"
ONNX_TARGET_OPSET = "onnx_target_opset"
"""For ONNX models we can set the target opset to use. 11 by default."""

TVM_MAX_FUSE_DEPTH = "tvm_max_fuse_depth"
"""For TVM we can fix the number of operations that will be fused.
If not set, compilation may take forever (https://github.com/microsoft/hummingbird/issues/232).
By default Hummingbird uses a max_fuse_depth of 50, but this can be overridden using this parameter."""

INPUT_NAMES = "input_names"
"""Set the names of the inputs. Assumes that the number of input_names is equal to the number of inputs."""

@ -17,7 +17,7 @@ from onnxconverter_common.data_types import (
)

import hummingbird.ml
from hummingbird.ml._utils import onnx_ml_tools_installed, onnx_runtime_installed
from hummingbird.ml._utils import onnx_ml_tools_installed, onnx_runtime_installed, tvm_installed
from hummingbird.ml.exceptions import MissingBackend

if onnx_ml_tools_installed():

@ -96,7 +96,24 @@ class TestBackends(unittest.TestCase):
# Test torchscript requires test_input
self.assertRaises(RuntimeError, hummingbird.ml.convert, model, "torch.jit")

# Test onnx no test_data, float input
# Test TVM requires test_data
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
def test_tvm_test_data(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)

model.fit(X, y)

# Test tvm requires test_input
self.assertRaises(RuntimeError, hummingbird.ml.convert, model, "tvm")

# Test onnx requires test_data or initial_types
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML tests require ONNX, ORT and ONNXMLTOOLS"
)

@ -17,7 +17,13 @@ from sklearn.pipeline import Pipeline
import torch

import hummingbird.ml
from hummingbird.ml._utils import onnx_ml_tools_installed, onnx_runtime_installed, pandas_installed, lightgbm_installed
from hummingbird.ml._utils import (
onnx_ml_tools_installed,
onnx_runtime_installed,
pandas_installed,
lightgbm_installed,
tvm_installed,
)
from hummingbird.ml import constants

if lightgbm_installed():

@ -384,6 +390,156 @@ class TestExtraConf(unittest.TestCase):
np.testing.assert_allclose(model.decision_function(X), hb_model.decision_function(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.score_samples(X), hb_model.score_samples(X), rtol=1e-06, atol=1e-06)

# Test tvm transform with batching.
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
def test_tvm_batch_transform(self):
warnings.filterwarnings("ignore")
model = StandardScaler(with_mean=True, with_std=True)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)

model.fit(X)

hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.BATCH_SIZE: 10})

self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.transform(X), hb_model.transform(X), rtol=1e-06, atol=1e-06)

# Test tvm regression with batching.
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
def test_tvm_regression_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingRegressor(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)

model.fit(X, y)

hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.BATCH_SIZE: 10})

self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)

# Test tvm classification with batching.
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
def test_tvm_classification_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)

model.fit(X, y)

hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.BATCH_SIZE: 10})

self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)

# Test tvm iforest with batching.
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
def test_tvm_iforest_batch(self):
warnings.filterwarnings("ignore")
num_classes = 2
model = IsolationForest(n_estimators=10, max_samples=2)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)

model.fit(X, y)

hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.BATCH_SIZE: 10})

self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.decision_function(X), hb_model.decision_function(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.score_samples(X), hb_model.score_samples(X), rtol=1e-06, atol=1e-06)

# Test tvm transform with batching and uneven number of records.
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
def test_tvm_batch_remainder_transform(self):
warnings.filterwarnings("ignore")
model = StandardScaler(with_mean=True, with_std=True)
np.random.seed(0)
X = np.random.rand(105, 200)
X = np.array(X, dtype=np.float32)

model.fit(X)

hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.BATCH_SIZE: 10})

self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.transform(X), hb_model.transform(X), rtol=1e-06, atol=1e-06)

# Test tvm regression with batching and uneven number of records.
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
def test_tvm_regression_remainder_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingRegressor(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(105, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=105)

model.fit(X, y)

hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.BATCH_SIZE: 10})

self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)

# Test tvm classification with batching and uneven number of records.
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
def test_tvm_classification_remainder_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(105, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=105)

model.fit(X, y)

hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.BATCH_SIZE: 10})

self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)

# Test tvm iforest with batching and uneven number of records.
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
def test_tvm_iforest_remainder_batch(self):
warnings.filterwarnings("ignore")
num_classes = 2
model = IsolationForest(n_estimators=10, max_samples=2)
np.random.seed(0)
X = np.random.rand(105, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=105)

model.fit(X, y)

hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.BATCH_SIZE: 10})

self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.decision_function(X), hb_model.decision_function(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.score_samples(X), hb_model.score_samples(X), rtol=1e-06, atol=1e-06)

# Test batch with pandas.
@unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed")
def test_pandas_batch(self):

@ -391,8 +547,8 @@ class TestExtraConf(unittest.TestCase):

max_depth = 10
iris = datasets.load_iris()
X = iris.data[:, :3]
y = iris.target
X = iris.data[:149, :3]
y = iris.target[:149]
columns = ["vA", "vB", "vC"]
X_train = pandas.DataFrame(X, columns=columns)

@ -413,15 +569,15 @@ class TestExtraConf(unittest.TestCase):
pipeline.predict_proba(X_train), torch_model.predict_proba(X_train), rtol=1e-06, atol=1e-06,
)

# Test batch with pandas.
# Test batch with pandas ts.
@unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed")
def test_pandas_batch_ts(self):
import pandas

max_depth = 10
iris = datasets.load_iris()
X = iris.data[:, :3]
y = iris.target
X = iris.data[:149, :3]
y = iris.target[:149]
columns = ["vA", "vB", "vC"]
X_train = pandas.DataFrame(X, columns=columns)

@ -442,7 +598,7 @@ class TestExtraConf(unittest.TestCase):
pipeline.predict_proba(X_train), torch_model.predict_proba(X_train), rtol=1e-06, atol=1e-06,
)

# Test batch with pandas.
# Test batch with pandas onnx.
@unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed")
@unittest.skipIf(not onnx_runtime_installed(), reason="ONNXML tests require ONNX and ORT")
def test_pandas_batch_onnx(self):

@ -450,8 +606,8 @@ class TestExtraConf(unittest.TestCase):

max_depth = 10
iris = datasets.load_iris()
X = iris.data[:, :3]
y = iris.target
X = iris.data[:149, :3]
y = iris.target[:149]
columns = ["vA", "vB", "vC"]
X_train = pandas.DataFrame(X, columns=columns)

@ -472,7 +628,7 @@ class TestExtraConf(unittest.TestCase):
pipeline.predict_proba(X_train), hb_model.predict_proba(X_train), rtol=1e-06, atol=1e-06,
)

# Test batch with pandas.
# Test batch with pandas from onnxml.
@unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed")
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML tests require ONNX, ORT and ONNXMLTOOLS"

@ -515,6 +671,36 @@ class TestExtraConf(unittest.TestCase):
pipeline.predict_proba(X_train), hb_model.predict_proba(X_train), rtol=1e-06, atol=1e-06,
)

# Test batch with pandas tvm.
@unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed")
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
def test_pandas_batch_tvm(self):
import pandas

max_depth = 10
iris = datasets.load_iris()
X = iris.data[:149, :3]
y = iris.target[:149]
columns = ["vA", "vB", "vC"]
X_train = pandas.DataFrame(X, columns=columns)

pipeline = Pipeline(
steps=[
("preprocessor", ColumnTransformer(transformers=[], remainder="passthrough",)),
("classifier", GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)),
]
)

pipeline.fit(X_train, y)

hb_model = hummingbird.ml.convert(pipeline, "tvm", X_train, extra_config={constants.BATCH_SIZE: 10})

self.assertIsNotNone(hb_model)

np.testing.assert_allclose(
pipeline.predict_proba(X_train), hb_model.predict_proba(X_train), rtol=1e-06, atol=1e-06,
)

# Check converter with model name set as extra_config.
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML tests require ONNX, ORT and ONNXMLTOOLS"

@ -539,6 +725,21 @@ class TestExtraConf(unittest.TestCase):

assert onnx_model.model.graph.name == model_name

# Test max fuse depth configuration in TVM
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
def test_tvm_max_fuse_depth(self):
warnings.filterwarnings("ignore")

X = [[0, 1], [1, 1], [2, 0]]
X = np.array(X, dtype=np.float32)
y = np.array([100, -10, 50], dtype=np.float32)
model = lgb.LGBMRegressor(n_estimators=3, min_child_samples=1)
model.fit(X, y)

hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)


if __name__ == "__main__":
unittest.main()

@ -7,7 +7,7 @@ import warnings
import numpy as np

import hummingbird.ml
from hummingbird.ml._utils import lightgbm_installed, onnx_runtime_installed
from hummingbird.ml._utils import lightgbm_installed, onnx_runtime_installed, tvm_installed
from tree_utils import gbdt_implementation_map

if lightgbm_installed():

@ -301,8 +301,6 @@ class TestLGBMConverter(unittest.TestCase):
@unittest.skipIf(not onnx_runtime_installed(), reason="ONNXML tests require ONNX, ORT and ONNXMLTOOLS")
@unittest.skipIf(not lightgbm_installed(), reason="LightGBM test requires LightGBM installed")
def test_lightgbm_onnx(self):
import onnxruntime as ort

warnings.filterwarnings("ignore")

X = [[0, 1], [1, 1], [2, 0]]

@ -314,10 +312,63 @@ class TestLGBMConverter(unittest.TestCase):
# Create ONNX model
onnx_model = hummingbird.ml.convert(model, "onnx", X)

# Get the predictions for the ONNX-ML model
onnx_pred = onnx_model.predict(X)
np.testing.assert_allclose(onnx_model.predict(X)[0].flatten(), model.predict(X))

np.testing.assert_allclose(onnx_pred[0].flatten(), model.predict(X))
# TVM backend tests.
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_lightgbm_tvm_regressor(self):
warnings.filterwarnings("ignore")

for tree_implementation in ["gemm", "tree_trav", "perf_tree_trav"]:
X = [[0, 1], [1, 1], [2, 0]]
X = np.array(X, dtype=np.float32)
y = np.array([100, -10, 50], dtype=np.float32)
model = lgb.LGBMRegressor(n_estimators=3, min_child_samples=1)
model.fit(X, y)

# Create TVM model.
tvm_model = hummingbird.ml.convert(model, "tvm", X, extra_config={"tree_implementation": tree_implementation})

# Check results.
np.testing.assert_allclose(tvm_model.predict(X), model.predict(X))

@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM installed")
def test_lightgbm_tvm_classifier(self):
warnings.filterwarnings("ignore")

for tree_implementation in ["gemm", "tree_trav", "perf_tree_trav"]:
X = [[0, 1], [1, 1], [2, 0]]
X = np.array(X, dtype=np.float32)
y = np.array([0, 1, 0], dtype=np.float32)
model = lgb.LGBMClassifier(n_estimators=3, min_child_samples=1)
model.fit(X, y)

# Create TVM model.
tvm_model = hummingbird.ml.convert(model, "tvm", X, extra_config={"tree_implementation": tree_implementation})

# Check results.
np.testing.assert_allclose(tvm_model.predict(X), model.predict(X))
np.testing.assert_allclose(tvm_model.predict_proba(X), model.predict_proba(X))

# Test TVM with large input datasets.
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM installed")
def test_lightgbm_tvm_classifier_large_dataset(self):
warnings.filterwarnings("ignore")

for tree_implementation in ["gemm", "tree_trav", "perf_tree_trav"]:
size = 200000
X = np.random.rand(size, 28)
X = np.array(X, dtype=np.float32)
y = np.random.randint(2, size=size)
model = lgb.LGBMClassifier(n_estimators=100, max_depth=3)
model.fit(X, y)

# Create TVM model.
tvm_model = hummingbird.ml.convert(model, "tvm", X, extra_config={"tree_implementation": tree_implementation})

# Check results.
np.testing.assert_allclose(tvm_model.predict(X), model.predict(X))
np.testing.assert_allclose(tvm_model.predict_proba(X), model.predict_proba(X), rtol=1e-06, atol=1e-06)


if __name__ == "__main__":

@ -6,7 +6,13 @@ import warnings

import numpy as np

from hummingbird.ml._utils import lightgbm_installed, xgboost_installed, onnx_runtime_installed, onnx_ml_tools_installed
from hummingbird.ml._utils import (
lightgbm_installed,
xgboost_installed,
onnx_runtime_installed,
onnx_ml_tools_installed,
tvm_installed,
)


class TestNoExtra(unittest.TestCase):

@ -39,6 +45,12 @@ class TestNoExtra(unittest.TestCase):
warnings.filterwarnings("ignore")
assert not onnx_ml_tools_installed()

# Test no TVM returns false on tvm_installed()
@unittest.skipIf(tvm_installed(), reason="Test when TVM is not installed")
def test_tvm_installed_false(self):
warnings.filterwarnings("ignore")
assert not tvm_installed()

# Test that we can import the converter successfully without installing [extra]
def test_import_convert_no_extra(self):
try:

@ -51,6 +51,33 @@ class TestONNXOneHotEncoder(unittest.TestCase):
# Check that predicted values match
np.testing.assert_allclose(onnx_ml_pred, onnx_pred, rtol=rtol, atol=atol)

# Test OneHotEncoder with 2 inputs
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
)
def test_one_hot_encoder_onnx2(self, rtol=1e-06, atol=1e-06):
model = OneHotEncoder()
X = np.array([[1, 2, 3], [2, 1, 3]], dtype=np.int32)
model.fit(X)

# Create ONNX-ML model
onnx_ml_model = convert_sklearn(model, initial_types=[("int_input", IntTensorType_onnx(X.shape))])

# Create ONNX model by calling converter
onnx_model = convert(onnx_ml_model, "onnx", X)

# Get the predictions for the ONNX-ML model
session = ort.InferenceSession(onnx_ml_model.SerializeToString())
output_names = [session.get_outputs()[i].name for i in range(len(session.get_outputs()))]
inputs = {session.get_inputs()[0].name: X}
onnx_ml_pred = session.run(output_names, inputs)

# Get the predictions for the ONNX model
onnx_pred = onnx_model.transform(X)

# Check that predicted values match
np.testing.assert_allclose(onnx_ml_pred, onnx_pred, rtol=rtol, atol=atol)

# Test OneHotEncoder with int64
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"

@ -8,7 +8,7 @@ import numpy as np
from sklearn.preprocessing import MaxAbsScaler, MinMaxScaler, StandardScaler, RobustScaler
import torch

from hummingbird.ml._utils import onnx_ml_tools_installed, onnx_runtime_installed, lightgbm_installed
from hummingbird.ml._utils import onnx_ml_tools_installed, onnx_runtime_installed
from hummingbird.ml import convert

if onnx_runtime_installed():

@ -10,6 +10,8 @@ from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

import hummingbird.ml
from hummingbird.ml.exceptions import MissingConverter
from hummingbird.ml._utils import tvm_installed
from hummingbird.ml import constants
from tree_utils import dt_implementation_map

@ -26,7 +28,9 @@ class TestSklearnTreeConverter(unittest.TestCase):
for extra_config_param in ["tree_trav", "perf_tree_trav", "gemm"]:
model.fit(X, y)

torch_model = hummingbird.ml.convert(model, "torch", extra_config={"tree_implementation": extra_config_param})
torch_model = hummingbird.ml.convert(
model, "torch", extra_config={constants.TREE_IMPLEMENTATION: extra_config_param}
)
self.assertIsNotNone(torch_model)
self.assertTrue(
str(type(list(torch_model.model._operator_map.values())[0])) == dt_implementation_map[extra_config_param]

@ -56,19 +60,19 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Random forest gemm classifier
def test_random_forest_gemm_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier, 2, extra_config={"tree_implementation": "gemm"}, n_estimators=10
RandomForestClassifier, 2, extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
)

# Random forest tree_trav classifier
def test_random_forest_tree_trav_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier, 2, extra_config={"tree_implementation": "tree_trav"}, n_estimators=10
RandomForestClassifier, 2, extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}, n_estimators=10
)

# Random forest perf_tree_trav classifier
def test_random_forest_perf_tree_trav_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier, 2, extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10
RandomForestClassifier, 2, extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"}, n_estimators=10
)

# Random forest multi classifier

@ -78,37 +82,45 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Random forest gemm multi classifier
def test_random_forest_gemm_multi_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier, 3, extra_config={"tree_implementation": "gemm"}, n_estimators=10
RandomForestClassifier, 3, extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
)

# Random forest tree_trav multi classifier
def test_random_forest_tree_trav_multi_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier, 3, extra_config={"tree_implementation": "tree_trav"}, n_estimators=10
RandomForestClassifier, 3, extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}, n_estimators=10
)

# Random forest perf_tree_trav multi classifier
def test_random_forest_perf_tree_trav_multi_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier, 3, extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10
RandomForestClassifier, 3, extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"}, n_estimators=10
)

# Random forest gemm classifier shifted classes
def test_random_forest_gemm_classifier_shifted_labels_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier, 3, labels_shift=2, extra_config={"tree_implementation": "gemm"}, n_estimators=10
RandomForestClassifier, 3, labels_shift=2, extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
)

# Random forest tree_trav classifier shifted classes
def test_random_forest_tree_trav_classifier_shifted_labels_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier, 3, labels_shift=2, extra_config={"tree_implementation": "tree_trav"}, n_estimators=10
RandomForestClassifier,
3,
labels_shift=2,
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"},
n_estimators=10,
)

# Random forest perf_tree_trav classifier shifted classes
def test_random_forest_perf_tree_trav_classifier_shifted_labels_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier, 3, labels_shift=2, extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10
RandomForestClassifier,
3,
labels_shift=2,
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"},
n_estimators=10,
)

# Used for regression tests

@ -133,19 +145,19 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Random forest gemm regressor
def test_random_forest_gemm_regressor_converter(self):
self._run_tree_regressor_converter(
RandomForestRegressor, 1000, extra_config={"tree_implementation": "gemm"}, n_estimators=10
RandomForestRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
)

# Random forest tree_trav regressor
def test_random_forest_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
RandomForestRegressor, 1000, extra_config={"tree_implementation": "tree_trav"}, n_estimators=10
RandomForestRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}, n_estimators=10
)

# Random forest perf_tree_trav regressor
def test_random_forest_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
RandomForestRegressor, 1000, extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10
RandomForestRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"}, n_estimators=10
)

# Extra trees regressor

@ -155,19 +167,19 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Extra trees gemm regressor
def test_extra_trees_gemm_regressor_converter(self):
self._run_tree_regressor_converter(
ExtraTreesRegressor, 1000, extra_config={"tree_implementation": "gemm"}, n_estimators=10
ExtraTreesRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
)

# Extra trees tree_trav regressor
def test_extra_trees_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
ExtraTreesRegressor, 1000, extra_config={"tree_implementation": "tree_trav"}, n_estimators=10
ExtraTreesRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}, n_estimators=10
)

# Extra trees perf_tree_trav regressor
def test_extra_trees_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
ExtraTreesRegressor, 1000, extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10
ExtraTreesRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"}, n_estimators=10
)

# Decision tree regressor

@ -176,15 +188,19 @@ class TestSklearnTreeConverter(unittest.TestCase):

# Decision tree gemm regressor
def test_decision_tree_gemm_regressor_converter(self):
self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, extra_config={"tree_implementation": "gemm"})
self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "gemm"})

# Decision tree tree_trav regressor
def test_decision_tree_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, extra_config={"tree_implementation": "tree_trav"})
self._run_tree_regressor_converter(
DecisionTreeRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}
)

# Decision tree perf_tree_trav regressor
def test_decision_tree_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, extra_config={"tree_implementation": "perf_tree_trav"})
self._run_tree_regressor_converter(
DecisionTreeRegressor, 1000, extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
)

# Decision tree classifier
def test_decision_tree_classifier_converter(self):

@ -208,15 +224,19 @@ class TestSklearnTreeConverter(unittest.TestCase):

# Small tree gemm implementation
def test_random_forest_gemm_classifier_single_node_tree_converter(self):
self._run_random_forest_classifier_single_node_tree_converter(extra_config={"tree_implementation": "gemm"})
self._run_random_forest_classifier_single_node_tree_converter(extra_config={constants.TREE_IMPLEMENTATION: "gemm"})

# Small tree tree_trav implementation
def test_random_forest_tree_trav_classifier_single_node_tree_converter(self):
self._run_random_forest_classifier_single_node_tree_converter(extra_config={"tree_implementation": "tree_trav"})
self._run_random_forest_classifier_single_node_tree_converter(
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}
)

# Small tree perf_tree_trav implementation
def test_random_forest_perf_tree_trav_classifier_single_node_tree_converter(self):
self._run_random_forest_classifier_single_node_tree_converter(extra_config={"tree_implementation": "perf_tree_trav"})
self._run_random_forest_classifier_single_node_tree_converter(
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
)

# Float 64 classification test helper
def _run_float64_tree_classification_converter(self, model_type, num_classes, extra_config={}, labels_shift=0, **kwargs):

@ -287,7 +307,7 @@ class TestSklearnTreeConverter(unittest.TestCase):
y = np.random.randint(3, size=100)
model = RandomForestClassifier(n_estimators=10).fit(X, y)
self.assertRaises(
MissingConverter, hummingbird.ml.convert, model, "torch", extra_config={"tree_implementation": "nonsense"}
MissingConverter, hummingbird.ml.convert, model, "torch", extra_config={constants.TREE_IMPLEMENTATION: "nonsense"}
)

# Test trees with TorchScript backend

@ -298,19 +318,23 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Random forest gemm classifier
def test_random_forest_ts_gemm_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier, 2, "torch.jit", extra_config={"tree_implementation": "gemm"}, n_estimators=10
RandomForestClassifier, 2, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
)

# Random forest tree_trav classifier
def test_random_forest_ts_tree_trav_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier, 2, "torch.jit", extra_config={"tree_implementation": "tree_trav"}, n_estimators=10
RandomForestClassifier, 2, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}, n_estimators=10
)

# Random forest perf_tree_trav classifier
def test_random_forest_ts_perf_tree_trav_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier, 2, "torch.jit", extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10
RandomForestClassifier,
2,
"torch.jit",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"},
n_estimators=10,
)

# Random forest multi classifier

@ -320,19 +344,23 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Random forest gemm multi classifier
def test_random_forest_ts_gemm_multi_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier, 3, "torch.jit", extra_config={"tree_implementation": "gemm"}, n_estimators=10
RandomForestClassifier, 3, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
)

# Random forest tree_trav multi classifier
def test_random_forest_ts_tree_trav_multi_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier, 3, "torch.jit", extra_config={"tree_implementation": "tree_trav"}, n_estimators=10
RandomForestClassifier, 3, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}, n_estimators=10
)

# Random forest perf_tree_trav multi classifier
def test_random_forest_ts_perf_tree_trav_multi_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier, 3, "torch.jit", extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10
RandomForestClassifier,
3,
"torch.jit",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"},
n_estimators=10,
)

# Random forest gemm classifier shifted classes

@ -342,7 +370,7 @@ class TestSklearnTreeConverter(unittest.TestCase):
3,
"torch.jit",
labels_shift=2,
extra_config={"tree_implementation": "gemm"},
extra_config={constants.TREE_IMPLEMENTATION: "gemm"},
n_estimators=10,
)

@ -353,7 +381,7 @@ class TestSklearnTreeConverter(unittest.TestCase):
3,
"torch.jit",
labels_shift=2,
extra_config={"tree_implementation": "tree_trav"},
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"},
n_estimators=10,
)

@ -364,7 +392,7 @@ class TestSklearnTreeConverter(unittest.TestCase):
3,
"torch.jit",
labels_shift=2,
extra_config={"tree_implementation": "perf_tree_trav"},
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"},
n_estimators=10,
)

@ -375,19 +403,27 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Random forest gemm regressor
def test_random_forest_ts_gemm_regressor_converter(self):
self._run_tree_regressor_converter(
RandomForestRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "gemm"}, n_estimators=10
RandomForestRegressor, 1000, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
)

# Random forest tree_trav regressor
def test_random_forest_ts_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
RandomForestRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "tree_trav"}, n_estimators=10
RandomForestRegressor,
1000,
"torch.jit",
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"},
n_estimators=10,
)

# Random forest perf_tree_trav regressor
def test_random_forest_ts_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
RandomForestRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10
RandomForestRegressor,
1000,
"torch.jit",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"},
n_estimators=10,
)

# Extra trees regressor

@ -397,19 +433,23 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Extra trees gemm regressor
def test_extra_trees_ts_gemm_regressor_converter(self):
self._run_tree_regressor_converter(
ExtraTreesRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "gemm"}, n_estimators=10
ExtraTreesRegressor, 1000, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "gemm"}, n_estimators=10
)

# Extra trees tree_trav regressor
def test_extra_trees_ts_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
ExtraTreesRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "tree_trav"}, n_estimators=10
ExtraTreesRegressor, 1000, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}, n_estimators=10
)

# Extra trees perf_tree_trav regressor
def test_extra_trees_ts_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
ExtraTreesRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "perf_tree_trav"}, n_estimators=10
ExtraTreesRegressor,
1000,
"torch.jit",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"},
n_estimators=10,
)

# Decision tree regressor

@ -419,19 +459,19 @@ class TestSklearnTreeConverter(unittest.TestCase):
# Decision tree gemm regressor
def test_decision_tree_ts_gemm_regressor_converter(self):
self._run_tree_regressor_converter(
DecisionTreeRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "gemm"}
DecisionTreeRegressor, 1000, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "gemm"}
)

# Decision tree tree_trav regressor
def test_decision_tree_ts_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
DecisionTreeRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "tree_trav"}
DecisionTreeRegressor, 1000, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "tree_trav"}
)

# Decision tree perf_tree_trav regressor
def test_decision_tree_ts_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
DecisionTreeRegressor, 1000, "torch.jit", extra_config={"tree_implementation": "perf_tree_trav"}
DecisionTreeRegressor, 1000, "torch.jit", extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
)

# Decision tree classifier

@ -444,6 +484,224 @@ class TestSklearnTreeConverter(unittest.TestCase):
def test_extra_trees_ts_classifier_converter(self):
self._run_tree_classification_converter(ExtraTreesClassifier, 3, "torch.jit", n_estimators=10)

# Test trees with TVM backend
# Random forest gemm classifier
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_gemm_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
2,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "gemm", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)

# Random forest tree_trav classifier
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_tree_trav_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
2,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)

# Random forest perf_tree_trav classifier
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_perf_tree_trav_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
2,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)

# Random forest gemm multi classifier
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_gemm_multi_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
3,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "gemm", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)

# Random forest tree_trav multi classifier
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_tree_trav_multi_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
3,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)

# Random forest perf_tree_trav multi classifier
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_perf_tree_trav_multi_classifier_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
3,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)

# Random forest gemm classifier shifted classes
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_gemm_classifier_shifted_labels_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
3,
"tvm",
labels_shift=2,
extra_config={constants.TREE_IMPLEMENTATION: "gemm", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)

# Random forest tree_trav classifier shifted classes
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_tree_trav_classifier_shifted_labels_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
3,
"tvm",
labels_shift=2,
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)

# Random forest perf_tree_trav classifier shifted classes
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_perf_tree_trav_classifier_shifted_labels_converter(self):
self._run_tree_classification_converter(
RandomForestClassifier,
3,
"tvm",
labels_shift=2,
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav", constants.TVM_MAX_FUSE_DEPTH: 10},
n_estimators=10,
)

# Random forest gemm regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_gemm_regressor_converter(self):
self._run_tree_regressor_converter(
RandomForestRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "gemm", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)

# Random forest tree_trav regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
RandomForestRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)

# Random forest perf_tree_trav regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_random_forest_tvm_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
RandomForestRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav", constants.TVM_MAX_FUSE_DEPTH: 10},
n_estimators=10,
)

# Extra trees gemm regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_extra_trees_tvm_gemm_regressor_converter(self):
self._run_tree_regressor_converter(
ExtraTreesRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "gemm", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)

# Extra trees tree_trav regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_extra_trees_tvm_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
ExtraTreesRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30},
n_estimators=10,
)

# Extra trees perf_tree_trav regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_extra_trees_tvm_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
ExtraTreesRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav", constants.TVM_MAX_FUSE_DEPTH: 10},
n_estimators=10,
)

# Decision tree regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_decision_tree_tvm_regressor_converter(self):
self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, "tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})

# Decision tree gemm regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_decision_tree_tvm_gemm_regressor_converter(self):
self._run_tree_regressor_converter(
DecisionTreeRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "gemm", constants.TVM_MAX_FUSE_DEPTH: 30},
)

# Decision tree tree_trav regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_decision_tree_tvm_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
DecisionTreeRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30},
)

# Decision tree perf_tree_trav regressor
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_decision_tree_tvm_perf_tree_trav_regressor_converter(self):
self._run_tree_regressor_converter(
DecisionTreeRegressor,
1000,
"tvm",
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav", constants.TVM_MAX_FUSE_DEPTH: 10},
)

# Decision tree classifier
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_decision_tree_tvm_classifier_converter(self):
self._run_tree_classification_converter(
DecisionTreeClassifier, 3, "tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30}
)

# Extra trees classifier
@unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
def test_extra_trees_tvm_classifier_converter(self):
self._run_tree_classification_converter(
ExtraTreesClassifier, 3, "tvm", n_estimators=10, extra_config={constants.TVM_MAX_FUSE_DEPTH: 30}
)


if __name__ == "__main__":
unittest.main()

@ -9,8 +9,9 @@ import torch
from sklearn.ensemble import IsolationForest

import hummingbird.ml
from hummingbird.ml import constants
from hummingbird.ml._utils import onnx_runtime_installed, tvm_installed
from tree_utils import iforest_implementation_map
from hummingbird.ml._utils import onnx_runtime_installed


class TestIsolationForestConverter(unittest.TestCase):
@ -104,6 +105,23 @@ class TestIsolationForestConverter(unittest.TestCase):
|
|||
np.testing.assert_allclose(model.score_samples(X), onnx_model.score_samples(X), rtol=1e-06, atol=1e-06)
|
||||
np.testing.assert_array_equal(model.predict(X), onnx_model.predict(X))
|
||||
|
||||
# Test TVM backend.
|
||||
@unittest.skipIf(not (tvm_installed()), reason="TVM test requires TVM")
|
||||
def test_isolation_forest_tvm_converter(self):
|
||||
warnings.filterwarnings("ignore")
|
||||
for max_samples in [2 ** 1, 2 ** 3, 2 ** 8, 2 ** 10, 2 ** 12]:
|
||||
model = IsolationForest(n_estimators=10, max_samples=max_samples)
|
||||
np.random.seed(0)
|
||||
X = np.random.rand(100, 200)
|
||||
X = np.array(X, dtype=np.float32)
|
||||
model.fit(X)
|
||||
hb_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
|
||||
|
||||
self.assertIsNotNone(hb_model)
|
||||
np.testing.assert_allclose(model.decision_function(X), hb_model.decision_function(X), rtol=1e-06, atol=1e-06)
|
||||
np.testing.assert_allclose(model.score_samples(X), hb_model.score_samples(X), rtol=1e-06, atol=1e-06)
|
||||
np.testing.assert_array_equal(model.predict(X), hb_model.predict(X))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
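# Editorial note (hedged): as the assertions above suggest, the container
# returned by hummingbird.ml.convert mirrors the prediction surface of the
# input estimator, which is why predict, decision_function, and score_samples
# can be called on the sklearn and TVM models interchangeably.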
@@ -10,6 +10,8 @@ from sklearn.linear_model import LinearRegression, LogisticRegression, SGDClassi
from sklearn import datasets

import hummingbird.ml
from hummingbird.ml._utils import tvm_installed
from hummingbird.ml import constants


class TestSklearnLinearClassifiers(unittest.TestCase):
@@ -224,6 +226,43 @@ class TestSklearnLinearClassifiers(unittest.TestCase):
        np.testing.assert_allclose(model.predict(X), ts_model.predict(X), rtol=1e-6, atol=1e-6)
        np.testing.assert_allclose(model.predict_proba(X), ts_model.predict_proba(X), rtol=1e-6, atol=1e-6)

    # Test TVM backends.
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_sgd_classifier_tvm(self):

        model = SGDClassifier(loss="log")

        np.random.seed(0)
        num_classes = 3
        X = np.random.rand(100, 200)
        X = np.array(X, dtype=np.float32)
        y = np.random.randint(num_classes, size=100)

        model.fit(X, y)

        tvm_model = hummingbird.ml.convert(model, "tvm", X)
        self.assertTrue(tvm_model is not None)
        np.testing.assert_allclose(model.predict(X), tvm_model.predict(X), rtol=1e-6, atol=1e-6)
        np.testing.assert_allclose(model.predict_proba(X), tvm_model.predict_proba(X), rtol=1e-6, atol=1e-6)

    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_lr_tvm(self):

        model = LinearRegression()

        np.random.seed(0)
        num_classes = 1000
        X = np.random.rand(100, 200)
        X = np.array(X, dtype=np.float32)
        y = np.random.randint(num_classes, size=100)

        model.fit(X, y)

        tvm_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
        self.assertTrue(tvm_model is not None)

        np.testing.assert_allclose(model.predict(X), tvm_model.predict(X), rtol=1e-6, atol=1e-3)


if __name__ == "__main__":
    unittest.main()
@@ -9,12 +9,14 @@ import torch
from sklearn.neural_network import MLPClassifier, MLPRegressor

import hummingbird.ml
from hummingbird.ml import constants
from hummingbird.ml._utils import tvm_installed


class TestSklearnMLPClassifier(unittest.TestCase):

    # MLPClassifier test function to be parameterized
    def _test_mlp_classifer(self, num_classes, activation="relu", labels_shift=0):
    def _test_mlp_classifer(self, num_classes, activation="relu", labels_shift=0, backend="torch", extra_config={}):
        model = MLPClassifier(hidden_layer_sizes=(100, 100, 50,), activation=activation)
        np.random.seed(0)
        X = np.random.rand(100, 200)
@@ -22,7 +24,7 @@ class TestSklearnMLPClassifier(unittest.TestCase):
        y = np.random.randint(num_classes, size=100) + labels_shift

        model.fit(X, y)
        torch_model = hummingbird.ml.convert(model, "torch")
        torch_model = hummingbird.ml.convert(model, backend, X, extra_config=extra_config)
        self.assertTrue(torch_model is not None)
        np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-6, atol=1e-6)

@@ -50,6 +52,37 @@ class TestSklearnMLPClassifier(unittest.TestCase):
    def test_mlp_classifer_multi_identity(self):
        self._test_mlp_classifer(3, activation="identity")

    # Test TVM backend
    # MLPClassifier binary
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_mlp_classifer_bi_tvm(self):
        self._test_mlp_classifer(2, backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})

    # MLPClassifier multi-class
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_mlp_classifer_multi_tvm(self):
        self._test_mlp_classifer(3, backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})

    # MLPClassifier multi-class w/ shifted labels
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_mlp_classifer_multi_shifted_labels_tvm(self):
        self._test_mlp_classifer(3, labels_shift=3, backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})

    # MLPClassifier multi-class w/ tanh activation
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_mlp_classifer_multi_tanh_tvm(self):
        self._test_mlp_classifer(3, activation="tanh", backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})

    # MLPClassifier multi-class w/ logistic activation
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_mlp_classifer_multi_logistic_tvm(self):
        self._test_mlp_classifer(3, activation="logistic", backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})

    # MLPClassifier multi-class w/ identity activation
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_mlp_classifer_multi_identity_tvm(self):
        self._test_mlp_classifer(3, activation="identity", backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})

    # MLPRegressor test function to be parameterized
    def _test_mlp_regressor(self, activation="relu"):
        model = MLPRegressor(hidden_layer_sizes=(100, 100, 50,), activation=activation)
@@ -9,13 +9,14 @@ import torch
from sklearn.naive_bayes import BernoulliNB, GaussianNB, MultinomialNB

import hummingbird.ml
from hummingbird.ml._utils import tvm_installed


class TestSklearnNBClassifier(unittest.TestCase):

    # BernoulliNB test function to be parameterized
    def _test_bernoulinb_classifer(
        self, num_classes, alpha=1.0, binarize=None, fit_prior=False, class_prior=None, labels_shift=0
        self, num_classes, alpha=1.0, binarize=None, fit_prior=False, class_prior=None, labels_shift=0, backend="torch"
    ):
        model = BernoulliNB(alpha=alpha, binarize=binarize, fit_prior=fit_prior, class_prior=class_prior)
        np.random.seed(0)
@@ -27,7 +28,7 @@ class TestSklearnNBClassifier(unittest.TestCase):
        y = np.random.randint(num_classes, size=100) + labels_shift

        model.fit(X, y)
        torch_model = hummingbird.ml.convert(model, "torch")
        torch_model = hummingbird.ml.convert(model, backend, X)
        self.assertTrue(torch_model is not None)
        np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-6, atol=1e-5)

@@ -61,8 +62,48 @@ class TestSklearnNBClassifier(unittest.TestCase):
    def test_bernoulinb_classifer_multi_labels_shift(self):
        self._test_bernoulinb_classifer(3, labels_shift=3)

    # Test TVM backend
    # BernoulliNB binary
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_bernoulinb_classifer_bi_tvm(self):
        self._test_bernoulinb_classifer(2, backend="tvm")

    # BernoulliNB multi-class
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_bernoulinb_classifer_multi_tvm(self):
        self._test_bernoulinb_classifer(3, backend="tvm")

    # BernoulliNB multi-class w/ modified alpha
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_bernoulinb_classifer_multi_alpha_tvm(self):
        self._test_bernoulinb_classifer(3, alpha=0.5, backend="tvm")

    # BernoulliNB multi-class w/ binarize
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_bernoulinb_classifer_multi_binarize_tvm(self):
        self._test_bernoulinb_classifer(3, binarize=0.5, backend="tvm")

    # BernoulliNB multi-class w/ fit prior
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_bernoulinb_classifer_multi_fit_prior_tvm(self):
        self._test_bernoulinb_classifer(3, fit_prior=True, backend="tvm")

    # BernoulliNB multi-class w/ class prior
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_bernoulinb_classifer_multi_class_prior_tvm(self):
        np.random.seed(0)
        class_prior = np.random.rand(3)
        self._test_bernoulinb_classifer(3, class_prior=class_prior, backend="tvm")

    # BernoulliNB multi-class w/ labels shift
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_bernoulinb_classifer_multi_labels_shift_tvm(self):
        self._test_bernoulinb_classifer(3, labels_shift=3, backend="tvm")

    # MultinomialNB test function to be parameterized
    def _test_multinomialnb_classifer(self, num_classes, alpha=1.0, fit_prior=False, class_prior=None, labels_shift=0):
    def _test_multinomialnb_classifer(
        self, num_classes, alpha=1.0, fit_prior=False, class_prior=None, labels_shift=0, backend="torch"
    ):
        model = MultinomialNB(alpha=alpha, fit_prior=fit_prior, class_prior=class_prior)
        np.random.seed(0)
        X = np.random.rand(100, 200)
@@ -70,7 +111,7 @@ class TestSklearnNBClassifier(unittest.TestCase):
        y = np.random.randint(num_classes, size=100) + labels_shift

        model.fit(X, y)
        torch_model = hummingbird.ml.convert(model, "torch")
        torch_model = hummingbird.ml.convert(model, backend, X)
        self.assertTrue(torch_model is not None)
        np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-6, atol=1e-5)

@@ -100,8 +141,41 @@ class TestSklearnNBClassifier(unittest.TestCase):
    def test_multinomialnb_classifer_multi_labels_shift(self):
        self._test_multinomialnb_classifer(3, labels_shift=3)

    # TVM Backend
    # MultinomialNB binary
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_multinomialnb_classifer_bi_tvm(self):
        self._test_multinomialnb_classifer(2, backend="tvm")

    # MultinomialNB multi-class
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_multinomialnb_classifer_multi_tvm(self):
        self._test_multinomialnb_classifer(3, backend="tvm")

    # MultinomialNB multi-class w/ modified alpha
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_multinomialnb_classifer_multi_alpha_tvm(self):
        self._test_multinomialnb_classifer(3, alpha=0.5, backend="tvm")

    # MultinomialNB multi-class w/ fit prior
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_multinomialnb_classifer_multi_fit_prior_tvm(self):
        self._test_multinomialnb_classifer(3, fit_prior=True, backend="tvm")

    # MultinomialNB multi-class w/ class prior
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_multinomialnb_classifer_multi_class_prior_tvm(self):
        np.random.seed(0)
        class_prior = np.random.rand(3)
        self._test_multinomialnb_classifer(3, class_prior=class_prior, backend="tvm")

    # MultinomialNB multi-class w/ labels shift
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_multinomialnb_classifer_multi_labels_shift_tvm(self):
        self._test_multinomialnb_classifer(3, labels_shift=3, backend="tvm")

    # GaussianNB test function to be parameterized
    def _test_gaussiannb_classifer(self, num_classes, priors=None, var_smoothing=1e-9, labels_shift=0):
    def _test_gaussiannb_classifer(self, num_classes, priors=None, var_smoothing=1e-9, labels_shift=0, backend="torch"):
        model = GaussianNB(priors=priors, var_smoothing=var_smoothing)
        np.random.seed(0)
        X = np.random.rand(100, 200)
@@ -109,9 +183,9 @@ class TestSklearnNBClassifier(unittest.TestCase):
        y = np.random.randint(num_classes, size=100) + labels_shift

        model.fit(X, y)
        torch_model = hummingbird.ml.convert(model, "torch")
        torch_model = hummingbird.ml.convert(model, backend, X)
        self.assertTrue(torch_model is not None)
        np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-6, atol=1e-5)
        np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-5, atol=1e-5)

    # GaussianNB binary
    def test_gaussiannb_classifer_bi(self):
@@ -136,6 +210,35 @@ class TestSklearnNBClassifier(unittest.TestCase):
    def test_gaussiannb_classifer_multi_labels_shift(self):
        self._test_gaussiannb_classifer(3, labels_shift=3)

    # TVM Backend
    # GaussianNB binary
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_gaussiannb_classifer_bi_tvm(self):
        self._test_gaussiannb_classifer(2, backend="tvm")

    # GaussianNB multi-class
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_gaussiannb_classifer_multi_tvm(self):
        self._test_gaussiannb_classifer(3, backend="tvm")

    # GaussianNB multi-class w/ class prior
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_gaussiannb_classifer_multi_class_prior_tvm(self):
        np.random.seed(0)
        priors = np.random.rand(3)
        priors = priors / np.sum(priors)
        self._test_gaussiannb_classifer(3, priors=priors, backend="tvm")

    # GaussianNB multi-class w/ modified var_smoothing
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_gaussiannb_classifer_multi_var_smoothing_tvm(self):
        self._test_gaussiannb_classifer(3, var_smoothing=1e-2, backend="tvm")

    # GaussianNB multi-class w/ labels shift
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_gaussiannb_classifer_multi_labels_shift_tvm(self):
        self._test_gaussiannb_classifer(3, labels_shift=3, backend="tvm")


if __name__ == "__main__":
    unittest.main()
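# Editorial note: MultinomialNB requires non-negative features, and
# np.random.rand draws from [0, 1), so the shared random X above is a valid
# input for all three naive Bayes variants exercised in this file.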
@@ -9,6 +9,8 @@ import torch
from sklearn.preprocessing import Normalizer

import hummingbird.ml
from hummingbird.ml import constants
from hummingbird.ml._utils import onnx_runtime_installed, tvm_installed


class TestSklearnNormalizer(unittest.TestCase):
@@ -61,6 +63,46 @@ class TestSklearnNormalizer(unittest.TestCase):
            model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06,
        )

    # ONNX backend
    @unittest.skipIf(not (onnx_runtime_installed()), reason="ONNXML test requires ONNX and ORT")
    def test_normalizer_converter_onnx(self):
        # Generate a random 2D array with values in [0, 1000)
        np.random.seed(0)
        data = np.random.rand(100, 200) * 1000
        data = np.array(data, dtype=np.float32)
        data_tensor = torch.from_numpy(data)

        for norm in ["l1", "l2", "max"]:
            model = Normalizer(norm=norm)
            model.fit(data)

            hb_model = hummingbird.ml.convert(model, "onnx", data)

            self.assertIsNotNone(hb_model)
            np.testing.assert_allclose(
                model.transform(data), hb_model.transform(data_tensor)[0], rtol=1e-06, atol=1e-06,
            )

    # TVM backend
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_normalizer_converter_tvm(self):
        # Generate a random 2D array with values in [0, 1000)
        np.random.seed(0)
        data = np.random.rand(100, 200) * 1000
        data = np.array(data, dtype=np.float32)
        data_tensor = torch.from_numpy(data)

        for norm in ["l1", "l2", "max"]:
            model = Normalizer(norm=norm)
            model.fit(data)

            torch_model = hummingbird.ml.convert(model, "tvm", data, extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})

            self.assertIsNotNone(torch_model)
            np.testing.assert_allclose(
                model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06,
            )


if __name__ == "__main__":
    unittest.main()
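# Editorial note (hedged, inferred from the test rather than from API docs):
# the ONNX assertion above indexes hb_model.transform(...)[0] while the TVM
# one does not, suggesting the ONNX container returns a sequence of outputs
# where the other backends return the array directly.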
@@ -7,39 +7,57 @@ import torch
from sklearn.preprocessing import RobustScaler, MaxAbsScaler, MinMaxScaler, StandardScaler

import hummingbird.ml
from hummingbird.ml._utils import tvm_installed
from hummingbird.ml import constants


class TestSklearnScalerConverter(unittest.TestCase):
    def test_robust_scaler_floats(self):
    def _test_robust_scaler_floats(self, with_centering, with_scaling, backend="torch"):
        # Generate a random 2D array with values in [0, 1000)
        np.random.seed(0)
        data = np.random.rand(100, 200) * 1000
        data = np.array(data, dtype=np.float32)
        data_tensor = torch.from_numpy(data)

        model = RobustScaler(with_centering=False, with_scaling=False)
        model = RobustScaler(with_centering=with_centering, with_scaling=with_scaling)
        model.fit(data)
        torch_model = hummingbird.ml.convert(model, "torch")
        torch_model = hummingbird.ml.convert(model, backend, data)
        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)

        model = RobustScaler(with_centering=False, with_scaling=True)
    def _test_standard_scaler_floats(self, with_mean, with_std, backend="torch"):
        # Generate a random 2D array with values in [0, 1000)
        np.random.seed(0)
        data = np.random.rand(100, 200) * 1000
        data = np.array(data, dtype=np.float32)
        data_tensor = torch.from_numpy(data)

        model = StandardScaler(with_mean=with_mean, with_std=with_std)
        model.fit(data)
        torch_model = hummingbird.ml.convert(model, "torch")
        torch_model = hummingbird.ml.convert(model, backend, data)
        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)

        model = RobustScaler(with_centering=True, with_scaling=False)
        model.fit(data)
        torch_model = hummingbird.ml.convert(model, "torch")
        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)
    def test_robust_scaler_floats_torch_false_false(self):
        self._test_robust_scaler_floats(False, False)

        model = RobustScaler(with_centering=True, with_scaling=True)
        model.fit(data)
        torch_model = hummingbird.ml.convert(model, "torch")
        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)
    def test_robust_scaler_floats_torch_true_false(self):
        self._test_robust_scaler_floats(True, False)

    def test_robust_scaler_floats_torch_false_true(self):
        self._test_robust_scaler_floats(False, True)

    def test_robust_scaler_floats_torch_true_true(self):
        self._test_robust_scaler_floats(True, True)

    def test_standard_scaler_floats_torch_false_false(self):
        self._test_standard_scaler_floats(False, False)

    def test_standard_scaler_floats_torch_true_false(self):
        self._test_standard_scaler_floats(True, False)

    def test_standard_scaler_floats_torch_true_true(self):
        self._test_standard_scaler_floats(True, True)

    def test_max_abs_scaler_floats(self):
        # Generate a random 2D array with values in [0, 1000)
@@ -69,31 +87,6 @@ class TestSklearnScalerConverter(unittest.TestCase):
        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)

    def test_standard_scaler_floats(self):
        # Generate a random 2D array with values in [0, 1000)
        np.random.seed(0)
        data = np.random.rand(100, 200) * 1000
        data = np.array(data, dtype=np.float32)
        data_tensor = torch.from_numpy(data)

        model = StandardScaler(with_mean=False, with_std=False)
        model.fit(data)
        torch_model = hummingbird.ml.convert(model, "torch")
        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)

        model = StandardScaler(with_mean=True, with_std=False)
        model.fit(data)
        torch_model = hummingbird.ml.convert(model, "torch")
        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-04, atol=1e-04)

        model = StandardScaler(with_mean=True, with_std=True)
        model.fit(data)
        torch_model = hummingbird.ml.convert(model, "torch")
        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)

    # Float 64 data tests
    def test_float64_robust_scaler_floats(self):
        # Generate a random 2D array with values in [0, 1000)
@@ -107,6 +100,35 @@ class TestSklearnScalerConverter(unittest.TestCase):
        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.transform(data), torch_model.transform(data_tensor), rtol=1e-06, atol=1e-06)

    # Tests with TVM backend
    @unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
    def test_standard_scaler_floats_tvm_false_false(self):
        self._test_standard_scaler_floats(False, False, "tvm")

    @unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
    def test_standard_scaler_floats_tvm_true_false(self):
        self._test_standard_scaler_floats(True, False, "tvm")

    @unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
    def test_standard_scaler_floats_tvm_true_true(self):
        self._test_standard_scaler_floats(True, True, "tvm")

    @unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
    def test_robust_scaler_floats_tvm_false_false(self):
        self._test_robust_scaler_floats(False, False, "tvm")

    @unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
    def test_robust_scaler_floats_tvm_true_false(self):
        self._test_robust_scaler_floats(True, False, "tvm")

    @unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
    def test_robust_scaler_floats_tvm_false_true(self):
        self._test_robust_scaler_floats(False, True, "tvm")

    @unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
    def test_robust_scaler_floats_tvm_true_true(self):
        self._test_robust_scaler_floats(True, True, "tvm")


if __name__ == "__main__":
    unittest.main()
@@ -8,6 +8,8 @@ import torch
from sklearn.svm import LinearSVC, SVC, NuSVC

import hummingbird.ml
from hummingbird.ml import constants
from hummingbird.ml._utils import tvm_installed


class TestSklearnSVC(unittest.TestCase):
@@ -36,7 +38,7 @@ class TestSklearnSVC(unittest.TestCase):
        self._test_linear_svc(3, labels_shift=2)

    # SVC test function to be parameterized
    def _test_svc(self, num_classes, kernel="rbf", gamma=None, labels_shift=0):
    def _test_svc(self, num_classes, kernel="rbf", gamma=None, backend="torch", labels_shift=0, extra_config={}):

        if gamma:
            model = SVC(kernel=kernel, gamma=gamma)
@@ -48,7 +50,7 @@ class TestSklearnSVC(unittest.TestCase):
        y = np.random.randint(num_classes, size=100) + labels_shift

        model.fit(X, y)
        torch_model = hummingbird.ml.convert(model, "torch")
        torch_model = hummingbird.ml.convert(model, backend, X, extra_config=extra_config)

        self.assertTrue(torch_model is not None)
        np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-6, atol=1e-6)
@@ -82,7 +84,7 @@ class TestSklearnSVC(unittest.TestCase):
        self._test_svc(3, gamma="auto")

    # NuSVC test function to be parameterized
    def _test_nu_svc(self, num_classes):
    def _test_nu_svc(self, num_classes, backend="torch", extra_config={}):
        model = NuSVC()
        np.random.seed(0)
        X = np.random.rand(100, 200)
@@ -90,7 +92,7 @@ class TestSklearnSVC(unittest.TestCase):
        y = np.random.randint(num_classes, size=100)

        model.fit(X, y)
        torch_model = hummingbird.ml.convert(model, "torch")
        torch_model = hummingbird.ml.convert(model, backend, X, extra_config=extra_config)

        self.assertTrue(torch_model is not None)
        np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-6, atol=1e-6)
@@ -126,6 +128,77 @@ class TestSklearnSVC(unittest.TestCase):
        self.assertTrue(torch_model is not None)
        np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-6, atol=1e-6)

    # TorchScript backend
    def test_svc_ts(self):
        self._test_svc(2, backend="torch.jit")

    # SVC linear kernel
    def test_svc_linear_ts(self):
        self._test_svc(2, kernel="linear", backend="torch.jit")

    # SVC sigmoid kernel
    def test_svc_sigmoid_ts(self):
        self._test_svc(2, kernel="sigmoid", backend="torch.jit")

    # SVC poly kernel
    def test_svc_poly_ts(self):
        self._test_svc(2, kernel="poly", backend="torch.jit")

    # NuSVC binary
    def test_nu_svc_bi_ts(self):
        self._test_nu_svc(2, backend="torch.jit")

    def test_svc_multi_ts(self):
        self._test_svc(3, backend="torch.jit")

    # TVM backend.
    # SVC binary
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_svc_tvm(self):
        self._test_svc(2, backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})

    # SVC linear kernel
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_svc_linear_tvm(self):
        self._test_svc(2, kernel="linear", backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})

    # SVC sigmoid kernel
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_svc_sigmoid_tvm(self):
        self._test_svc(2, kernel="sigmoid", backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})

    # SVC poly kernel
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_svc_poly_tvm(self):
        self._test_svc(2, kernel="poly", backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})

    # NuSVC binary
    @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM")
    def test_nu_svc_bi_tvm(self):
        self._test_nu_svc(2, backend="tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})

    # The multiclass SVC tests are commented out for TVM because an operator
    # implementing mode is currently missing
    # (https://github.com/pytorch/pytorch/issues/43867); see the hedged
    # workaround sketch after this file's tests.
    # # SVC multiclass
    # def test_svc_multi(self):
    #     self._test_svc(3)

    # # SVC sigmoid kernel
    # def test_svc_sigmoid(self):
    #     self._test_svc(3, kernel="sigmoid", backend="tvm")

    # # SVC poly kernel
    # def test_svc_poly(self):
    #     self._test_svc(3, kernel="poly", backend="tvm")

    # # SVC with class labels shifted
    # def test_svc_shifted(self):
    #     self._test_svc(3, labels_shift=2, backend="tvm")

    # # SVC with different gamma (default='scale')
    # def test_svc_gamma(self):
    #     self._test_svc(3, gamma="auto", backend="tvm")


if __name__ == "__main__":
    unittest.main()
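# Hedged sketch (an editorial assumption, not the converter's actual code) of
# one way to emulate the missing `mode` operator cited above: per-row mode
# over integer class votes can be computed with one-hot encoding, a sum, and
# an argmax, all of which are standard exportable tensor operations (ties
# resolve to the lowest class index).
#
#   import torch
#
#   def row_mode(votes, num_classes):
#       # votes: (batch, n_estimators) tensor of integer class votes.
#       one_hot = torch.nn.functional.one_hot(votes, num_classes)  # (batch, n, c)
#       return one_hot.sum(dim=1).argmax(dim=1)                    # (batch,)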
@@ -7,7 +7,8 @@ import warnings
import numpy as np

import hummingbird.ml
from hummingbird.ml._utils import xgboost_installed
from hummingbird.ml._utils import xgboost_installed, tvm_installed
from hummingbird.ml import constants
from tree_utils import gbdt_implementation_map

if xgboost_installed():
@@ -102,7 +103,7 @@ class TestXGBoostConverter(unittest.TestCase):

        model.fit(X, y, group=[X.shape[0]])

        torch_model = hummingbird.ml.convert(model, "torch", X[0:1], extra_config=extra_config)
        torch_model = hummingbird.ml.convert(model, "torch", X, extra_config=extra_config)
        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)

@@ -136,7 +137,7 @@ class TestXGBoostConverter(unittest.TestCase):
        y = np.random.randint(num_classes, size=100)

        model.fit(X, y)
        torch_model = hummingbird.ml.convert(model, "torch", X[0:1], extra_config=extra_config)
        torch_model = hummingbird.ml.convert(model, "torch", X, extra_config=extra_config)
        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)

@@ -225,6 +226,7 @@ class TestXGBoostConverter(unittest.TestCase):
        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)

    # Torchscript backends.
    # Test TorchScript backend regression.
    @unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires XGBoost installed")
    def test_xgb_regressor_converter_torchscript(self):
@@ -263,6 +265,47 @@ class TestXGBoostConverter(unittest.TestCase):
        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)

    # TVM backend tests.
    # TVM backend regression.
    @unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires XGBoost installed")
    @unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
    def test_xgb_regressor_converter_tvm(self):
        warnings.filterwarnings("ignore")
        import torch

        for max_depth in [1, 3, 8, 10, 12]:
            model = xgb.XGBRegressor(n_estimators=10, max_depth=max_depth)
            np.random.seed(0)
            X = np.random.rand(100, 200)
            X = np.array(X, dtype=np.float32)
            y = np.random.randint(1000, size=100)

            model.fit(X, y)

            tvm_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
            self.assertIsNotNone(tvm_model)
            np.testing.assert_allclose(model.predict(X), tvm_model.predict(X), rtol=1e-06, atol=1e-06)

    # Test TVM backend classification.
    @unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires XGBoost installed")
    @unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
    def test_xgb_classifier_converter_tvm(self):
        warnings.filterwarnings("ignore")
        import torch

        for max_depth in [1, 3, 8, 10, 12]:
            model = xgb.XGBClassifier(n_estimators=10, max_depth=max_depth)
            np.random.seed(0)
            X = np.random.rand(100, 200)
            X = np.array(X, dtype=np.float32)
            y = np.random.randint(2, size=100)

            model.fit(X, y)

            tvm_model = hummingbird.ml.convert(model, "tvm", X, extra_config={constants.TVM_MAX_FUSE_DEPTH: 30})
            self.assertIsNotNone(tvm_model)
            np.testing.assert_allclose(model.predict_proba(X), tvm_model.predict_proba(X), rtol=1e-06, atol=1e-06)


if __name__ == "__main__":
    unittest.main()
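# Editorial note (hedged): the convert calls in this file changed from passing
# X[0:1] to passing the full X as the sample input. A plausible reading, given
# that some backends specialize the compiled model on the shape of the sample
# data, is that tracing on a single row would bake in batch size 1 and
# mismatch the batch used in the assertions.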