This commit is contained in:
Matteo Interlandi 2021-03-24 10:53:28 +01:00
Родитель d2d622ab5c 17e6ce0cfd
Коммит bdddf5c2cb
11 изменённых файлов: 1526 добавлений и 65 удалений

Просмотреть файл

@ -91,7 +91,7 @@ In general, Hummingbird syntax is very intuitive and minimal. To run your tradit
```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from hummingbird.ml import convert
from hummingbird.ml import convert, load
# Create some random data for binary classification
num_classes = 2
@ -116,7 +116,7 @@ model.predict(X)
model.save('hb_model')
# Load the model back
model = hummingbird.ml.load('hb_model')
model = load('hb_model')
```
# Documentation

Просмотреть файл

@ -21,6 +21,7 @@ from .onnx import onnx_operator # noqa: E402
from .onnx import array_feature_extractor as onnx_afe # noqa: E402, F811
from .onnx import binarizer as onnx_binarizer # noqa: E402, F811
from .onnx import feature_vectorizer # noqa: E402
from .onnx import imputer as onnx_imputer # noqa: E402
from .onnx import label_encoder as onnx_label_encoder # noqa: E402, F811
from .onnx import linear as onnx_linear # noqa: E402, F811
from .onnx import normalizer as onnx_normalizer # noqa: E402, F811

Просмотреть файл

@ -0,0 +1,81 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""
Base classes for Imputers
"""
import torch
import numpy as np
from ._physical_operator import PhysicalOperator
from . import constants
class SimpleImputer(PhysicalOperator, torch.nn.Module):
"""
Class implementing SimpleImputer operators in PyTorch.
"""
def __init__(self, logical_operator, device, statistics=None, missing=None, strategy=None):
super(SimpleImputer, self).__init__(logical_operator)
sklearn_imputer = logical_operator.raw_operator
# Pull out the stats field from either the SKL imputer or args
stats_ = statistics if statistics is not None else sklearn_imputer.statistics_
# Process the stats into an array
stats = [float(stat) for stat in stats_]
missing_values = missing if missing is not None else sklearn_imputer.missing_values
strategy = strategy if strategy is not None else sklearn_imputer.strategy
b_mask = np.logical_not(np.isnan(stats))
i_mask = [i for i in range(len(b_mask)) if b_mask[i]]
self.transformer = True
self.do_mask = strategy == "constant" or all(b_mask)
self.mask = torch.nn.Parameter(torch.LongTensor([] if self.do_mask else i_mask), requires_grad=False)
self.replace_values = torch.nn.Parameter(torch.tensor([stats_], dtype=torch.float32), requires_grad=False)
self.is_nan = True if (missing_values == "NaN" or np.isnan(missing_values)) else False
if not self.is_nan:
self.missing_values = torch.nn.Parameter(torch.tensor([missing_values], dtype=torch.float32), requires_grad=False)
def forward(self, x):
if self.is_nan:
result = torch.where(torch.isnan(x), self.replace_values.expand(x.shape), x)
if self.do_mask:
return result
return torch.index_select(result, 1, self.mask)
else:
return torch.where(torch.eq(x, self.missing_values), self.replace_values.expand(x.shape), x)
class MissingIndicator(PhysicalOperator, torch.nn.Module):
"""
Class implementing Imputer operators in MissingIndicator.
"""
def __init__(self, logical_operator, device):
super(MissingIndicator, self).__init__(logical_operator)
sklearn_missing_indicator = logical_operator.raw_operator
self.transformer = True
self.missing_values = torch.nn.Parameter(
torch.tensor([sklearn_missing_indicator.missing_values], dtype=torch.float32), requires_grad=False
)
self.features = sklearn_missing_indicator.features
self.is_nan = True if (sklearn_missing_indicator.missing_values in ["NaN", None, np.nan]) else False
self.column_indices = torch.nn.Parameter(torch.LongTensor(sklearn_missing_indicator.features_), requires_grad=False)
def forward(self, x):
if self.is_nan:
if self.features == "all":
return torch.isnan(x).float()
else:
return torch.isnan(torch.index_select(x, 1, self.column_indices)).float()
else:
if self.features == "all":
return torch.eq(x, self.missing_values).float()
else:
return torch.eq(torch.index_select(x, 1, self.column_indices), self.missing_values).float()

Просмотреть файл

@ -0,0 +1,44 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""
Converter for ONNX-ML Imputer.
"""
import numpy as np
from onnxconverter_common.registration import register_converter
from .._imputer_implementations import SimpleImputer
def convert_onnx_imputer(operator, device=None, extra_config={}):
"""
Converter for `ai.onnx.ml.Imputer`
Args:
operator: An operator wrapping a `ai.onnx.ml.Imputer` model
device: String defining the type of device the converted operator should be run on
extra_config: Extra configuration used to select the best conversion strategy
Returns:
A PyTorch model
"""
stats = missing = None
for attr in operator.raw_operator.origin.attribute:
if attr.name == "imputed_value_floats":
stats = np.array(attr.floats).astype("float64")
elif attr.name == "replaced_value_float":
missing = attr.f
if any(v is None for v in [stats, missing]):
raise RuntimeError("Error parsing Imputer, found unexpected None. stats: {}, missing: {}", stats, missing)
# ONNXML has no "strategy" field, but always behaves similar to SKL's constant: "replace missing values with fill_value"
return SimpleImputer(operator, device, statistics=stats, missing=missing, strategy="constant")
register_converter("ONNXMLImputer", convert_onnx_imputer)

Просмотреть файл

@ -12,39 +12,7 @@ import numpy as np
from onnxconverter_common.registration import register_converter
import torch
class SimpleImputer(PhysicalOperator, torch.nn.Module):
"""
Class implementing SimpleImputer operators in PyTorch.
"""
def __init__(self, logical_operator, device):
super(SimpleImputer, self).__init__(logical_operator)
sklearn_imputer = logical_operator.raw_operator
stats = [float(stat) for stat in sklearn_imputer.statistics_ if isinstance(stat, float)]
b_mask = np.logical_not(np.isnan(stats))
i_mask = [i for i in range(len(b_mask)) if b_mask[i]]
self.transformer = True
self.do_mask = sklearn_imputer.strategy == "constant" or all(b_mask)
self.mask = torch.nn.Parameter(torch.LongTensor([] if self.do_mask else i_mask), requires_grad=False)
self.replace_values = torch.nn.Parameter(
torch.tensor([sklearn_imputer.statistics_], dtype=torch.float32), requires_grad=False
)
self.is_nan = True if (sklearn_imputer.missing_values == "NaN" or np.isnan(sklearn_imputer.missing_values)) else False
if not self.is_nan:
self.missing_values = torch.nn.Parameter(
torch.tensor([sklearn_imputer.missing_values], dtype=torch.float32), requires_grad=False
)
def forward(self, x):
if self.is_nan:
result = torch.where(torch.isnan(x), self.replace_values.expand(x.shape), x)
if self.do_mask:
return result
return torch.index_select(result, 1, self.mask)
else:
return torch.where(torch.eq(x, self.missing_values), self.replace_values.expand(x.shape), x)
from .._imputer_implementations import SimpleImputer, MissingIndicator
def convert_sklearn_simple_imputer(operator, device, extra_config):
@ -64,35 +32,6 @@ def convert_sklearn_simple_imputer(operator, device, extra_config):
return SimpleImputer(operator, device)
class MissingIndicator(PhysicalOperator, torch.nn.Module):
"""
Class implementing Imputer operators in MissingIndicator.
"""
def __init__(self, logical_operator, device):
super(MissingIndicator, self).__init__(logical_operator)
sklearn_missing_indicator = logical_operator.raw_operator
self.transformer = True
self.missing_values = torch.nn.Parameter(
torch.tensor([sklearn_missing_indicator.missing_values], dtype=torch.float32), requires_grad=False
)
self.features = sklearn_missing_indicator.features
self.is_nan = True if (sklearn_missing_indicator.missing_values in ["NaN", None, np.nan]) else False
self.column_indices = torch.nn.Parameter(torch.LongTensor(sklearn_missing_indicator.features_), requires_grad=False)
def forward(self, x):
if self.is_nan:
if self.features == "all":
return torch.isnan(x).float()
else:
return torch.isnan(torch.index_select(x, 1, self.column_indices)).float()
else:
if self.features == "all":
return torch.eq(x, self.missing_values).float()
else:
return torch.eq(torch.index_select(x, 1, self.column_indices), self.missing_values).float()
def convert_sklearn_missing_indicator(operator, device, extra_config):
"""
Converter for `sklearn.impute.MissingIndicator`

Просмотреть файл

@ -77,6 +77,7 @@ Binarizer,
Cast,
Concat,
Div,
Imputer,
LabelEncoder,
Less,
LinearClassifier,
@ -313,6 +314,7 @@ def _build_onnxml_operator_list():
"ArrayFeatureExtractor",
"Binarizer",
"FeatureVectorizer",
"Imputer",
"LabelEncoder",
"OneHotEncoder",
"Normalizer",

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -22,7 +22,7 @@ with open(README) as f:
long_description = long_description[start_pos:]
install_requires = [
"numpy>=1.15,<=1.19.4",
"numpy>=1.15,<=1.20.*",
"onnxconverter-common>=1.6.0,<=1.7.0",
"scipy<=1.5.4",
"scikit-learn>=0.21.3,<=0.23.2",

Просмотреть файл

@ -0,0 +1,101 @@
"""
Tests onnxml Imputer converter
"""
import unittest
import warnings
import numpy as np
import torch
from sklearn.impute import SimpleImputer
from hummingbird.ml._utils import onnx_ml_tools_installed, onnx_runtime_installed, lightgbm_installed
from hummingbird.ml import convert
if onnx_runtime_installed():
import onnxruntime as ort
if onnx_ml_tools_installed():
from onnxmltools import convert_sklearn
from onnxmltools.convert.common.data_types import FloatTensorType as FloatTensorType_onnx
class TestONNXImputer(unittest.TestCase):
def _test_imputer_converter(self, model, mode="onnx"):
warnings.filterwarnings("ignore")
X = np.array([[1, 2], [np.nan, 3], [7, 6]], dtype=np.float32)
model.fit(X)
# Create ONNX-ML model
onnx_ml_model = convert_sklearn(model, initial_types=[("float_input", FloatTensorType_onnx(X.shape))])
# Get the predictions for the ONNX-ML model
session = ort.InferenceSession(onnx_ml_model.SerializeToString())
output_names = [session.get_outputs()[i].name for i in range(len(session.get_outputs()))]
inputs = {session.get_inputs()[0].name: X}
onnx_ml_pred = session.run(output_names, inputs)[0]
# Create test model by calling converter
model = convert(onnx_ml_model, mode, X)
# Get the predictions for the test model
pred = model.transform(X)
return onnx_ml_pred, pred
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
)
def test_onnx_imputer_const(self, rtol=1e-06, atol=1e-06):
model = SimpleImputer(strategy="constant")
onnx_ml_pred, onnx_pred = self._test_imputer_converter(model)
# Check that predicted values match
np.testing.assert_allclose(onnx_ml_pred, onnx_pred, rtol=rtol, atol=atol)
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
)
def test_onnx_imputer_const_nan0(self, rtol=1e-06, atol=1e-06):
model = SimpleImputer(strategy="constant", fill_value=0)
onnx_ml_pred, onnx_pred = self._test_imputer_converter(model)
# Check that predicted values match
np.testing.assert_allclose(onnx_ml_pred, onnx_pred, rtol=rtol, atol=atol)
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
)
def test_onnx_imputer_mean(self, rtol=1e-06, atol=1e-06):
model = SimpleImputer(strategy="mean", fill_value="nan")
onnx_ml_pred, onnx_pred = self._test_imputer_converter(model)
# Check that predicted values match
np.testing.assert_allclose(onnx_ml_pred, onnx_pred, rtol=rtol, atol=atol)
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
)
def test_onnx_imputer_converter_raises_rt(self):
warnings.filterwarnings("ignore")
model = SimpleImputer(strategy="mean", fill_value="nan")
X = np.array([[1, 2], [np.nan, 3], [7, 6]], dtype=np.float32)
model.fit(X)
# Create ONNX-ML model
onnx_ml_model = convert_sklearn(model, initial_types=[("float_input", FloatTensorType_onnx(X.shape))])
onnx_ml_model.graph.node[0].attribute[0].name = "".encode()
self.assertRaises(RuntimeError, convert, onnx_ml_model, "onnx", X)
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
)
def test_onnx_imputer_torch(self, rtol=1e-06, atol=1e-06):
model = SimpleImputer(strategy="constant")
onnx_ml_pred, onnx_pred = self._test_imputer_converter(model, mode="torch")
# Check that predicted values match
np.testing.assert_allclose(onnx_ml_pred, onnx_pred, rtol=rtol, atol=atol)
if __name__ == "__main__":
unittest.main()

Просмотреть файл

@ -52,6 +52,12 @@ class TestSklearnTreeConverter(unittest.TestCase):
self.assertIsNotNone(torch_model)
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
from distutils.version import LooseVersion
import torch
if LooseVersion(torch.__version__) >= LooseVersion("1.7.0"):
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)
# Random forest binary classifier
def test_random_forest_classifier_binary_converter(self):
self._run_tree_classification_converter(RandomForestClassifier, 2, n_estimators=10)