Add batching to output containers (#323)

This commit is contained in:
Matteo Interlandi 2020-10-21 07:51:44 -07:00 коммит произвёл GitHub
Родитель 554f5382f9
Коммит 5fd2cf4467
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 544 добавлений и 55 удалений

Просмотреть файл

@ -67,6 +67,51 @@ class SklearnContainer(ABC):
def model(self):
return self._model
def _run(self, function, *inputs, reshape=False):
"""
This function either score the full dataset at once or triggers batch inference.
"""
if DataFrame is not None and type(inputs[0]) == DataFrame:
# Split the dataframe into column ndarrays.
inputs = inputs[0]
input_names = list(inputs.columns)
splits = [inputs[input_names[idx]] for idx in range(len(input_names))]
inputs = [df.to_numpy().reshape(-1, 1) for df in splits]
if self._batch_size is None:
return function(*inputs)
else:
return self._run_batch_inference(function, *inputs, reshape=reshape)
def _run_batch_inference(self, function, *inputs, reshape=False):
"""
This function contains the code to run batched inference.
"""
is_tuple = type(inputs) is tuple
if is_tuple:
total_size = inputs[0].shape[0]
else:
total_size = inputs.shape[0]
iterations = total_size // self._batch_size
iterations += 1 if total_size % self._batch_size > 0 else 0
iterations = max(1, iterations)
predictions = []
for i in range(0, iterations):
start = i * self._batch_size
end = min(start + self._batch_size, total_size)
if is_tuple:
batch = tuple([input[start:end, :] for input in inputs])
else:
batch = inputs[start:end, :]
predictions.extend(function(*batch).ravel())
if reshape:
return np.array(predictions).ravel().reshape(total_size, -1)
return np.array(predictions).ravel()
class PyTorchTorchscriptSklearnContainer(SklearnContainer):
"""
@ -91,12 +136,15 @@ class PyTorchSklearnContainerTransformer(PyTorchTorchscriptSklearnContainer):
Container mirroring Sklearn transformers API.
"""
def _transform(self, *inputs):
return self.model.forward(*inputs).cpu().numpy()
def transform(self, *inputs):
"""
Utility functions used to emulate the behavior of the Sklearn API.
On data transformers it returns transformed output data
"""
return self.model.forward(*inputs).cpu().numpy()
return self._run(self._transform, *inputs, reshape=True)
class PyTorchSklearnContainerRegression(PyTorchTorchscriptSklearnContainer):
@ -114,6 +162,14 @@ class PyTorchSklearnContainerRegression(PyTorchTorchscriptSklearnContainer):
self._is_regression = is_regression
self._is_anomaly_detection = is_anomaly_detection
def _predict(self, *inputs):
if self._is_regression:
return self.model.forward(*inputs).cpu().numpy().ravel()
elif self._is_anomaly_detection:
return self.model.forward(*inputs)[0].cpu().numpy().ravel()
else:
return self.model.forward(*inputs)[0].cpu().numpy().ravel()
def predict(self, *inputs):
"""
Utility functions used to emulate the behavior of the Sklearn API.
@ -121,12 +177,7 @@ class PyTorchSklearnContainerRegression(PyTorchTorchscriptSklearnContainer):
On classification tasks returns the predicted class labels for the input data.
On anomaly detection (e.g. isolation forest) returns the predicted classes (-1 or 1).
"""
if self._is_regression:
return self.model.forward(*inputs).cpu().numpy().flatten()
elif self._is_anomaly_detection:
return self.model.forward(*inputs)[0].cpu().numpy().flatten()
else:
return self.model.forward(*inputs)[0].cpu().numpy()
return self._run(self._predict, *inputs)
class PyTorchSklearnContainerClassification(PyTorchSklearnContainerRegression):
@ -139,12 +190,15 @@ class PyTorchSklearnContainerClassification(PyTorchSklearnContainerRegression):
model, n_threads, batch_size, is_regression=False, extra_config=extra_config
)
def _predict_proba(self, *input):
return self.model.forward(*input)[1].cpu().numpy()
def predict_proba(self, *inputs):
"""
Utility functions used to emulate the behavior of the Sklearn API.
On classification tasks returns the probability estimates.
"""
return self.model.forward(*inputs)[1].cpu().numpy()
return self._run(self._predict_proba, *inputs, reshape=True)
class PyTorchSklearnContainerAnomalyDetection(PyTorchSklearnContainerRegression):
@ -157,12 +211,16 @@ class PyTorchSklearnContainerAnomalyDetection(PyTorchSklearnContainerRegression)
model, n_threads, batch_size, is_regression=False, is_anomaly_detection=True, extra_config=extra_config
)
def _decision_function(self, *inputs):
return self.model.forward(*inputs)[1].cpu().numpy().ravel()
def decision_function(self, *inputs):
"""
Utility functions used to emulate the behavior of the Sklearn API.
On anomaly detection (e.g. isolation forest) returns the decision function scores.
"""
scores = self.model.forward(*inputs)[1].cpu().numpy().flatten()
scores = self._run(self._decision_function, *inputs)
# Backward compatibility for sklearn <= 0.21
if constants.IFOREST_THRESHOLD in self._extra_config:
scores += self._extra_config[constants.IFOREST_THRESHOLD]
@ -180,26 +238,18 @@ class PyTorchSklearnContainerAnomalyDetection(PyTorchSklearnContainerRegression)
def _torchscript_wrapper(device, function, *inputs):
"""
This function contains the code to enable predictions over torchscript models.
It is used to wrap pytorch container functions.
It is used to translates inputs in the proper torch format.
"""
inputs = [*inputs]
with torch.no_grad():
if type(inputs) == DataFrame and DataFrame is not None:
# Split the dataframe into column ndarrays.
inputs = inputs[0]
input_names = list(inputs.columns)
splits = [inputs[input_names[idx]] for idx in range(len(input_names))]
splits = [df.to_numpy().reshape(-1, 1) for df in splits]
inputs = tuple(splits)
# Maps data inputs to the expected type and device.
for i in range(len(inputs)):
if type(inputs[i]) is np.ndarray:
inputs[i] = torch.from_numpy(inputs[i]).float()
elif type(inputs[i]) is not torch.Tensor:
raise RuntimeError("Inputer tensor {} of not supported type {}".format(i, type(inputs[i])))
if device != "cpu" and device is not None:
if device.type != "cpu" and device is not None:
inputs[i] = inputs[i].to(device)
return function(*inputs)
@ -211,9 +261,10 @@ class TorchScriptSklearnContainerTransformer(PyTorchSklearnContainerTransformer)
def transform(self, *inputs):
device = _get_device(self.model)
f = super(TorchScriptSklearnContainerTransformer, self).transform
f = super(TorchScriptSklearnContainerTransformer, self)._transform
f_wrapped = lambda x: _torchscript_wrapper(device, f, x) # noqa: E731
return _torchscript_wrapper(device, f, *inputs)
return self._run(f_wrapped, *inputs, reshape=True)
class TorchScriptSklearnContainerRegression(PyTorchSklearnContainerRegression):
@ -223,9 +274,10 @@ class TorchScriptSklearnContainerRegression(PyTorchSklearnContainerRegression):
def predict(self, *inputs):
device = _get_device(self.model)
f = super(TorchScriptSklearnContainerRegression, self).predict
f = super(TorchScriptSklearnContainerRegression, self)._predict
f_wrapped = lambda x: _torchscript_wrapper(device, f, x) # noqa: E731
return _torchscript_wrapper(device, f, *inputs)
return self._run(f_wrapped, *inputs)
class TorchScriptSklearnContainerClassification(PyTorchSklearnContainerClassification):
@ -235,15 +287,17 @@ class TorchScriptSklearnContainerClassification(PyTorchSklearnContainerClassific
def predict(self, *inputs):
device = _get_device(self.model)
f = super(TorchScriptSklearnContainerClassification, self).predict
f = super(TorchScriptSklearnContainerClassification, self)._predict
f_wrapped = lambda x: _torchscript_wrapper(device, f, x) # noqa: E731
return _torchscript_wrapper(device, f, *inputs)
return self._run(f_wrapped, *inputs)
def predict_proba(self, *inputs):
device = _get_device(self.model)
f = super(TorchScriptSklearnContainerClassification, self).predict_proba
f = super(TorchScriptSklearnContainerClassification, self)._predict_proba
f_wrapped = lambda *x: _torchscript_wrapper(device, f, *x) # noqa: E731
return _torchscript_wrapper(device, f, *inputs)
return self._run(f_wrapped, *inputs, reshape=True)
class TorchScriptSklearnContainerAnomalyDetection(PyTorchSklearnContainerAnomalyDetection):
@ -253,21 +307,24 @@ class TorchScriptSklearnContainerAnomalyDetection(PyTorchSklearnContainerAnomaly
def predict(self, *inputs):
device = _get_device(self.model)
f = super(TorchScriptSklearnContainerAnomalyDetection, self).predict
f = super(TorchScriptSklearnContainerAnomalyDetection, self)._predict
f_wrapped = lambda x: _torchscript_wrapper(device, f, x) # noqa: E731
return _torchscript_wrapper(device, f, *inputs)
return self._run(f_wrapped, *inputs)
def decision_function(self, *inputs):
device = _get_device(self.model)
f = super(TorchScriptSklearnContainerAnomalyDetection, self).decision_function
f = super(TorchScriptSklearnContainerAnomalyDetection, self)._decision_function
f_wrapped = lambda x: _torchscript_wrapper(device, f, x) # noqa: E731
return _torchscript_wrapper(device, f, *inputs)
return self._run(f_wrapped, *inputs)
def score_samples(self, *inputs):
device = _get_device(self.model)
f = super(TorchScriptSklearnContainerAnomalyDetection, self).score_samples
f = self.decision_function
f_wrapped = lambda x: _torchscript_wrapper(device, f, x) # noqa: E731
return _torchscript_wrapper(device, f, *inputs)
return self._run(f_wrapped, *inputs) + self._extra_config[constants.OFFSET]
# ONNX containers.
@ -324,14 +381,17 @@ class ONNXSklearnContainerTransformer(ONNXSklearnContainer):
assert len(self._output_names) == 1
def _transform(self, *inputs):
named_inputs = self._get_named_inputs(inputs)
return np.array(self._session.run(self._output_names, named_inputs))
def transform(self, *inputs):
"""
Utility functions used to emulate the behavior of the Sklearn API.
On data transformers it returns transformed output data
"""
named_inputs = self._get_named_inputs(inputs)
return self._session.run(self._output_names, named_inputs)
return self._run(self._transform, *inputs, reshape=True)
class ONNXSklearnContainerRegression(ONNXSklearnContainer):
@ -351,7 +411,7 @@ class ONNXSklearnContainerRegression(ONNXSklearnContainer):
self._is_regression = is_regression
self._is_anomaly_detection = is_anomaly_detection
def predict(self, *inputs):
def _predict(self, *inputs):
"""
Utility functions used to emulate the behavior of the Sklearn API.
On regression returns the predicted values.
@ -361,11 +421,18 @@ class ONNXSklearnContainerRegression(ONNXSklearnContainer):
named_inputs = self._get_named_inputs(inputs)
if self._is_regression:
return self._session.run(self._output_names, named_inputs)
return np.array(self._session.run(self._output_names, named_inputs))
elif self._is_anomaly_detection:
return np.array(self._session.run([self._output_names[0]], named_inputs))[0].flatten()
return np.array(self._session.run([self._output_names[0]], named_inputs))[0].ravel()
else:
return self._session.run([self._output_names[0]], named_inputs)[0]
return np.array(self._session.run([self._output_names[0]], named_inputs))[0]
def predict(self, *inputs):
"""
Utility functions used to emulate the behavior of the Sklearn API.
On data transformers it returns transformed output data
"""
return self._run(self._predict, *inputs)
class ONNXSklearnContainerClassification(ONNXSklearnContainerRegression):
@ -380,7 +447,7 @@ class ONNXSklearnContainerClassification(ONNXSklearnContainerRegression):
assert len(self._output_names) == 2
def predict_proba(self, *inputs):
def _predict_proba(self, *inputs):
"""
Utility functions used to emulate the behavior of the Sklearn API.
On classification tasks returns the probability estimates.
@ -389,6 +456,13 @@ class ONNXSklearnContainerClassification(ONNXSklearnContainerRegression):
return self._session.run([self._output_names[1]], named_inputs)[0]
def predict_proba(self, *inputs):
"""
Utility functions used to emulate the behavior of the Sklearn API.
On data transformers it returns transformed output data
"""
return self._run(self._predict_proba, *inputs, reshape=True)
class ONNXSklearnContainerAnomalyDetection(ONNXSklearnContainerRegression):
"""
@ -402,7 +476,7 @@ class ONNXSklearnContainerAnomalyDetection(ONNXSklearnContainerRegression):
assert len(self._output_names) == 2
def decision_function(self, *inputs):
def _decision_function(self, *inputs):
"""
Utility functions used to emulate the behavior of the Sklearn API.
On anomaly detection (e.g. isolation forest) returns the decision function scores.
@ -415,6 +489,13 @@ class ONNXSklearnContainerAnomalyDetection(ONNXSklearnContainerRegression):
scores += self._extra_config[constants.IFOREST_THRESHOLD]
return scores
def decision_function(self, *inputs):
"""
Utility functions used to emulate the behavior of the Sklearn API.
On data transformers it returns transformed output data
"""
return self._run(self._decision_function, *inputs)
def score_samples(self, *inputs):
"""
Utility functions used to emulate the behavior of the Sklearn API.

Просмотреть файл

@ -32,7 +32,7 @@ class LinearModel(BaseOperator, torch.nn.Module):
self.binary_classification = True
def forward(self, x):
output = torch.addmm(self.intercepts, x, self.coefficients)
output = torch.addmm(self.intercepts, x.float(), self.coefficients)
if self.multi_class == "multinomial":
output = torch.softmax(output, dim=1)
elif self.regression:

Просмотреть файл

@ -21,7 +21,13 @@ with open(README) as f:
if start_pos >= 0:
long_description = long_description[start_pos:]
install_requires = ["numpy>=1.15", "onnxconverter-common>=1.6.0", "scikit-learn>=0.21.3", "torch>=1.4.*", "psutil"]
install_requires = [
"numpy>=1.15",
"onnxconverter-common>=1.6.0",
"scikit-learn>=0.21.3",
"torch>=1.4.*",
"psutil",
]
onnx_requires = [
"onnxruntime>=1.0.0",
"onnxmltools>=1.6.0",

Просмотреть файл

@ -8,18 +8,23 @@ import warnings
import sys
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from onnxconverter_common.data_types import FloatTensorType
from sklearn import datasets
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor, IsolationForest
from sklearn.preprocessing import StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
import torch
import hummingbird.ml
from hummingbird.ml._utils import onnx_ml_tools_installed, onnx_runtime_installed, lightgbm_installed
from hummingbird.ml._utils import onnx_ml_tools_installed, onnx_runtime_installed, pandas_installed, lightgbm_installed
from hummingbird.ml import constants
if lightgbm_installed():
import lightgbm as lgb
if onnx_ml_tools_installed():
from onnxmltools.convert import convert_lightgbm
from onnxmltools.convert import convert_sklearn, convert_lightgbm
class TestExtraConf(unittest.TestCase):
@ -81,7 +86,12 @@ class TestExtraConf(unittest.TestCase):
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "onnx", X)
# Create ONNX-ML model
onnx_ml_model = convert_sklearn(
model, initial_types=[("input", FloatTensorType([X.shape[0], X.shape[1]]))], target_opset=9
)
hb_model = hummingbird.ml.convert(onnx_ml_model, "onnx", X)
self.assertIsNotNone(hb_model)
self.assertTrue(hb_model._session.get_session_options().intra_op_num_threads == psutil.cpu_count(logical=False))
@ -109,6 +119,357 @@ class TestExtraConf(unittest.TestCase):
self.assertTrue(hb_model._session.get_session_options().intra_op_num_threads == 1)
self.assertTrue(hb_model._session.get_session_options().inter_op_num_threads == 1)
# Test pytorch regressor with batching.
def test_torch_regression_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingRegressor(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "torch", extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
# Test pytorch classifier with batching.
def test_torch_classification_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "torch", extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)
# Test pytorch classifier with batching.
def test_torch_iforest_batch(self):
warnings.filterwarnings("ignore")
num_classes = 2
model = IsolationForest(n_estimators=10, max_samples=2)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "torch", extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.decision_function(X), hb_model.decision_function(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.score_samples(X), hb_model.score_samples(X), rtol=1e-06, atol=1e-06)
# Test pytorch regressor with batching and uneven rows.
def test_torch_batch_regression_uneven(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingRegressor(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(105, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=105)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "torch", extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
# Test pytorch classification with batching and uneven rows.
def test_torch_batch_classification_uneven(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(105, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=105)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "torch", extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
# Test pytorch transform with batching and uneven rows.
def test_torch_batch_transform(self):
warnings.filterwarnings("ignore")
model = StandardScaler(with_mean=True, with_std=True)
np.random.seed(0)
X = np.random.rand(105, 200)
X = np.array(X, dtype=np.float32)
model.fit(X)
hb_model = hummingbird.ml.convert(model, "torch", extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.transform(X), hb_model.transform(X), rtol=1e-06, atol=1e-06)
# Test torchscript regression with batching.
def test_torchscript_regression_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingRegressor(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(103, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=103)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "torch.jit", X, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
# Test torchscript classification with batching.
def test_torchscript_classification_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(103, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=103)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "torch.jit", X, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)
# Test torchscript iforest with batching.
def test_torchscript_iforest_batch(self):
warnings.filterwarnings("ignore")
num_classes = 2
model = IsolationForest(n_estimators=10, max_samples=2)
np.random.seed(0)
X = np.random.rand(103, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=103)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "torch.jit", X, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.decision_function(X), hb_model.decision_function(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.score_samples(X), hb_model.score_samples(X), rtol=1e-06, atol=1e-06)
# Test torchscript transform with batching and uneven rows.
def test_torchscript_batch_transform(self):
warnings.filterwarnings("ignore")
model = StandardScaler(with_mean=True, with_std=True)
np.random.seed(0)
X = np.random.rand(101, 200)
X = np.array(X, dtype=np.float32)
model.fit(X)
hb_model = hummingbird.ml.convert(model, "torch.jit", X, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.transform(X), hb_model.transform(X), rtol=1e-06, atol=1e-06)
# Test onnx transform with batching and uneven rows.
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test require ONNX, ORT and ONNXMLTOOLS"
)
def test_onnx_batch_transform(self):
warnings.filterwarnings("ignore")
model = StandardScaler(with_mean=True, with_std=True)
np.random.seed(0)
X = np.random.rand(101, 200)
X = np.array(X, dtype=np.float32)
model.fit(X)
hb_model = hummingbird.ml.convert(model, "onnx", X, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.transform(X), hb_model.transform(X), rtol=1e-06, atol=1e-06)
# Test onnx regression with batching.
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test require ONNX, ORT and ONNXMLTOOLS"
)
def test_onnx_regression_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingRegressor(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(103, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=103)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "onnx", X, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
# Test onnx classification with batching.
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test require ONNX, ORT and ONNXMLTOOLS"
)
def test_onnx_classification_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(103, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=103)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "onnx", X, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)
# Test onnx iforest with batching.
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test require ONNX, ORT and ONNXMLTOOLS"
)
def test_onnx_iforest_batch(self):
warnings.filterwarnings("ignore")
num_classes = 2
model = IsolationForest(n_estimators=10, max_samples=2)
np.random.seed(0)
X = np.random.rand(103, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=103)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "onnx", X, extra_config={constants.BATCH_SIZE: 10})
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict(X), hb_model.predict(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.decision_function(X), hb_model.decision_function(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(model.score_samples(X), hb_model.score_samples(X), rtol=1e-06, atol=1e-06)
# Test batch with pandas.
@unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed")
def test_pandas_batch(self):
import pandas
max_depth = 10
iris = datasets.load_iris()
X = iris.data[:, :3]
y = iris.target
columns = ["vA", "vB", "vC"]
X_train = pandas.DataFrame(X, columns=columns)
pipeline = Pipeline(
steps=[
("preprocessor", ColumnTransformer(transformers=[], remainder="passthrough",)),
("classifier", GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)),
]
)
pipeline.fit(X_train, y)
torch_model = hummingbird.ml.convert(pipeline, "torch", X_train, extra_config={constants.BATCH_SIZE: 10})
self.assertTrue(torch_model is not None)
np.testing.assert_allclose(
pipeline.predict_proba(X_train), torch_model.predict_proba(X_train), rtol=1e-06, atol=1e-06,
)
# Test batch with pandas.
@unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed")
def test_pandas_batch_ts(self):
import pandas
max_depth = 10
iris = datasets.load_iris()
X = iris.data[:, :3]
y = iris.target
columns = ["vA", "vB", "vC"]
X_train = pandas.DataFrame(X, columns=columns)
pipeline = Pipeline(
steps=[
("preprocessor", ColumnTransformer(transformers=[], remainder="passthrough",)),
("classifier", GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)),
]
)
pipeline.fit(X_train, y)
torch_model = hummingbird.ml.convert(pipeline, "torch.jit", X_train, extra_config={constants.BATCH_SIZE: 10})
self.assertTrue(torch_model is not None)
np.testing.assert_allclose(
pipeline.predict_proba(X_train), torch_model.predict_proba(X_train), rtol=1e-06, atol=1e-06,
)
# Test batch with pandas.
@unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed")
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test require ONNX, ORT and ONNXMLTOOLS"
)
def test_pandas_batch_onnx(self):
import pandas
max_depth = 10
iris = datasets.load_iris()
X = iris.data[:, :3]
y = iris.target
columns = ["vA", "vB", "vC"]
X_train = pandas.DataFrame(X, columns=columns)
pipeline = Pipeline(
steps=[
("preprocessor", ColumnTransformer(transformers=[], remainder="passthrough",)),
("classifier", GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)),
]
)
pipeline.fit(X_train, y)
hb_model = hummingbird.ml.convert(pipeline, "onnx", X_train, extra_config={constants.BATCH_SIZE: 10})
self.assertTrue(hb_model is not None)
np.testing.assert_allclose(
pipeline.predict_proba(X_train), hb_model.predict_proba(X_train), rtol=1e-06, atol=1e-06,
)
# Check converter with model name set as extra_config.
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test require ONNX, ORT and ONNXMLTOOLS"

Просмотреть файл

@ -260,6 +260,47 @@ class TestSklearnPipeline(unittest.TestCase):
model.predict_proba(X_test), torch_model.predict_proba(X_test), rtol=1e-06, atol=1e-06,
)
@unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed")
def test_pipeline_column_transformer_pandas_ts(self):
iris = datasets.load_iris()
X = np.array(iris.data[:, :3], np.float32) # If we don't use float32 here, with python 3.5 and torch 1.5.1 will fail.
y = iris.target
X_train = pandas.DataFrame(X, columns=["vA", "vB", "vC"])
X_train["vcat"] = X_train["vA"].apply(lambda x: 1 if x > 0.5 else 2)
X_train["vcat2"] = X_train["vB"].apply(lambda x: 3 if x > 0.5 else 4)
y_train = y % 2
numeric_features = [0, 1, 2] # ["vA", "vB", "vC"]
categorical_features = [3, 4] # ["vcat", "vcat2"]
classifier = LogisticRegression(
C=0.01, class_weight=dict(zip([False, True], [0.2, 0.8])), n_jobs=1, max_iter=10, solver="liblinear", tol=1e-3,
)
numeric_transformer = Pipeline(steps=[("scaler", StandardScaler())])
categorical_transformer = Pipeline(steps=[("onehot", OneHotEncoder(sparse=True, handle_unknown="ignore"))])
preprocessor = ColumnTransformer(
transformers=[
("num", numeric_transformer, numeric_features),
("cat", categorical_transformer, categorical_features),
]
)
model = Pipeline(steps=[("preprocessor", preprocessor), ("classifier", classifier)])
model.fit(X_train, y_train)
X_test = X_train[:11]
torch_model = hummingbird.ml.convert(model, "torch.jit", X_test)
self.assertTrue(torch_model is not None)
np.testing.assert_allclose(
model.predict_proba(X_test), torch_model.predict_proba(X_test), rtol=1e-06, atol=1e-06,
)
@unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed")
def test_pipeline_column_transformer_weights(self):
iris = datasets.load_iris()
@ -288,7 +329,7 @@ class TestSklearnPipeline(unittest.TestCase):
transformer_weights={"num": 2, "cat": 3},
)
model = Pipeline(steps=[("precprocessor", preprocessor), ("classifier", classifier)])
model = Pipeline(steps=[("preprocessor", preprocessor), ("classifier", classifier)])
model.fit(X_train, y_train)

Просмотреть файл

@ -30,20 +30,20 @@ class TestSparkMLDiscretizers(unittest.TestCase):
@unittest.skipIf(LooseVersion(torch.__version__) < LooseVersion("1.6.0"), reason="Spark-ML test requires torch >= 1.6.0")
def test_quantilediscretizer_converter(self):
iris = load_iris()
features = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
features = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
pd_df = pd.DataFrame(data=np.c_[iris['data'], iris['target']], columns=features + ['target'])
df = sql.createDataFrame(pd_df).select('sepal_length')
pd_df = pd.DataFrame(data=np.c_[iris["data"], iris["target"]], columns=features + ["target"])
df = sql.createDataFrame(pd_df).select("sepal_length")
quantile = QuantileDiscretizer(inputCol='sepal_length', outputCol='sepal_length_bucket', numBuckets=2)
quantile = QuantileDiscretizer(inputCol="sepal_length", outputCol="sepal_length_bucket", numBuckets=2)
model = quantile.fit(df)
test_df = df
torch_model = convert(model, "torch", test_df)
self.assertTrue(torch_model is not None)
spark_output = model.transform(test_df).select('sepal_length_bucket').toPandas()
torch_output_np = torch_model.transform(pd_df)
spark_output = model.transform(test_df).select("sepal_length_bucket").toPandas()
torch_output_np = torch_model.transform(pd_df[["sepal_length"]])
np.testing.assert_allclose(spark_output.to_numpy(), torch_output_np, rtol=1e-06, atol=1e-06)