Merge branch 'main' of https://github.com/microsoft/hummingbird into main
This commit is contained in:
Коммит
bdddf5c2cb
|
@ -91,7 +91,7 @@ In general, Hummingbird syntax is very intuitive and minimal. To run your tradit
|
||||||
```python
|
```python
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.ensemble import RandomForestClassifier
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
from hummingbird.ml import convert
|
from hummingbird.ml import convert, load
|
||||||
|
|
||||||
# Create some random data for binary classification
|
# Create some random data for binary classification
|
||||||
num_classes = 2
|
num_classes = 2
|
||||||
|
@ -116,7 +116,7 @@ model.predict(X)
|
||||||
model.save('hb_model')
|
model.save('hb_model')
|
||||||
|
|
||||||
# Load the model back
|
# Load the model back
|
||||||
model = hummingbird.ml.load('hb_model')
|
model = load('hb_model')
|
||||||
```
|
```
|
||||||
|
|
||||||
# Documentation
|
# Documentation
|
||||||
|
|
|
@ -21,6 +21,7 @@ from .onnx import onnx_operator # noqa: E402
|
||||||
from .onnx import array_feature_extractor as onnx_afe # noqa: E402, F811
|
from .onnx import array_feature_extractor as onnx_afe # noqa: E402, F811
|
||||||
from .onnx import binarizer as onnx_binarizer # noqa: E402, F811
|
from .onnx import binarizer as onnx_binarizer # noqa: E402, F811
|
||||||
from .onnx import feature_vectorizer # noqa: E402
|
from .onnx import feature_vectorizer # noqa: E402
|
||||||
|
from .onnx import imputer as onnx_imputer # noqa: E402
|
||||||
from .onnx import label_encoder as onnx_label_encoder # noqa: E402, F811
|
from .onnx import label_encoder as onnx_label_encoder # noqa: E402, F811
|
||||||
from .onnx import linear as onnx_linear # noqa: E402, F811
|
from .onnx import linear as onnx_linear # noqa: E402, F811
|
||||||
from .onnx import normalizer as onnx_normalizer # noqa: E402, F811
|
from .onnx import normalizer as onnx_normalizer # noqa: E402, F811
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License. See License.txt in the project root for
|
||||||
|
# license information.
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
"""
|
||||||
|
Base classes for Imputers
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ._physical_operator import PhysicalOperator
|
||||||
|
from . import constants
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleImputer(PhysicalOperator, torch.nn.Module):
    """
    Class implementing SimpleImputer operators in PyTorch.

    Args:
        logical_operator: Operator wrapper; its `raw_operator` attribute is read for
            any of `statistics`, `missing` or `strategy` that is not passed explicitly
        device: Accepted for signature compatibility; not used by this constructor
        statistics: Optional per-column replacement values
        missing: Optional placeholder value that marks an entry as missing
        strategy: Optional imputation strategy name
    """

    def __init__(self, logical_operator, device, statistics=None, missing=None, strategy=None):
        super(SimpleImputer, self).__init__(logical_operator)
        sklearn_imputer = logical_operator.raw_operator

        # Explicit arguments take precedence; the raw operator is only consulted
        # for fields that were not supplied by the caller.
        stats_ = statistics if statistics is not None else sklearn_imputer.statistics_
        # Normalise every statistic to a plain Python float.
        stats = [float(stat) for stat in stats_]

        missing_values = missing if missing is not None else sklearn_imputer.missing_values
        strategy = strategy if strategy is not None else sklearn_imputer.strategy

        # Positions of the statistics that are not NaN.
        b_mask = np.logical_not(np.isnan(stats))
        i_mask = np.flatnonzero(b_mask).tolist()

        self.transformer = True
        # NOTE: despite its name, `do_mask` is True when NO column selection is
        # applied in `forward` (constant strategy, or no NaN statistic).
        self.do_mask = strategy == "constant" or all(b_mask)
        self.mask = torch.nn.Parameter(torch.LongTensor([] if self.do_mask else i_mask), requires_grad=False)
        # Built from the normalised list rather than the raw input, so a
        # list-wrapped ndarray is never handed to torch.tensor.
        self.replace_values = torch.nn.Parameter(torch.tensor([stats], dtype=torch.float32), requires_grad=False)

        self.is_nan = bool(missing_values == "NaN" or np.isnan(missing_values))
        if not self.is_nan:
            # Only needed when the placeholder is a regular number compared with torch.eq.
            self.missing_values = torch.nn.Parameter(
                torch.tensor([missing_values], dtype=torch.float32), requires_grad=False
            )

    def forward(self, x):
        """
        Replace missing entries of `x` with the stored replacement values.

        When the placeholder is NaN, the columns listed in `self.mask` are selected
        afterwards unless `self.do_mask` is set. When the placeholder is a regular
        value, the result is returned without any column selection.
        """
        if not self.is_nan:
            return torch.where(torch.eq(x, self.missing_values), self.replace_values.expand(x.shape), x)

        result = torch.where(torch.isnan(x), self.replace_values.expand(x.shape), x)
        if self.do_mask:
            return result
        return torch.index_select(result, 1, self.mask)
|
||||||
|
|
||||||
|
|
||||||
|
class MissingIndicator(PhysicalOperator, torch.nn.Module):
    """
    Class implementing MissingIndicator operators in PyTorch.

    Args:
        logical_operator: Operator wrapper; `missing_values`, `features` and
            `features_` are read from its `raw_operator` attribute
        device: Accepted for signature compatibility; not used by this constructor
    """

    def __init__(self, logical_operator, device):
        super(MissingIndicator, self).__init__(logical_operator)
        sklearn_missing_indicator = logical_operator.raw_operator
        raw_missing = sklearn_missing_indicator.missing_values

        self.transformer = True
        # NaN never compares equal to itself, so a list-membership test only finds
        # it by object identity. Detect it explicitly so that any NaN float (not
        # just the `np.nan` singleton) is recognised.
        self.is_nan = (
            raw_missing is None
            or (isinstance(raw_missing, str) and raw_missing == "NaN")
            or (isinstance(raw_missing, (float, np.floating)) and bool(np.isnan(raw_missing)))
        )
        # None / "NaN" cannot be turned into a float tensor; store NaN for them so
        # the parameter always exists.
        self.missing_values = torch.nn.Parameter(
            torch.tensor([float("nan") if self.is_nan else raw_missing], dtype=torch.float32), requires_grad=False
        )
        self.features = sklearn_missing_indicator.features
        self.column_indices = torch.nn.Parameter(torch.LongTensor(sklearn_missing_indicator.features_), requires_grad=False)

    def forward(self, x):
        """
        Return a float tensor holding 1.0 where an entry is missing and 0.0 elsewhere.

        Unless `self.features` is "all", only the columns in `self.column_indices`
        are inspected and returned.
        """
        if self.features != "all":
            x = torch.index_select(x, 1, self.column_indices)
        if self.is_nan:
            return torch.isnan(x).float()
        return torch.eq(x, self.missing_values).float()
|
|
@ -0,0 +1,44 @@
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License. See License.txt in the project root for
|
||||||
|
# license information.
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
"""
|
||||||
|
Converter for ONNX-ML Imputer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from onnxconverter_common.registration import register_converter
|
||||||
|
|
||||||
|
from .._imputer_implementations import SimpleImputer
|
||||||
|
|
||||||
|
|
||||||
|
def convert_onnx_imputer(operator, device=None, extra_config={}):
    """
    Converter for `ai.onnx.ml.Imputer`

    Args:
        operator: An operator wrapping a `ai.onnx.ml.Imputer` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy
            (not read or modified by this converter)

    Returns:
        A PyTorch model

    Raises:
        RuntimeError: If the `imputed_value_floats` or `replaced_value_float`
            attribute is not found on the ONNX node
    """

    stats = missing = None
    for attr in operator.raw_operator.origin.attribute:
        if attr.name == "imputed_value_floats":
            stats = np.array(attr.floats).astype("float64")
        elif attr.name == "replaced_value_float":
            missing = attr.f

    if stats is None or missing is None:
        # Format the message here: passing the values as extra exception args
        # would leave the "{}" placeholders unfilled.
        raise RuntimeError(
            "Error parsing Imputer, found unexpected None. stats: {}, missing: {}".format(stats, missing)
        )

    # ONNXML has no "strategy" field, but always behaves similar to SKL's constant: "replace missing values with fill_value"
    return SimpleImputer(operator, device, statistics=stats, missing=missing, strategy="constant")
|
||||||
|
|
||||||
|
|
||||||
|
register_converter("ONNXMLImputer", convert_onnx_imputer)
|
|
@ -12,39 +12,7 @@ import numpy as np
|
||||||
from onnxconverter_common.registration import register_converter
|
from onnxconverter_common.registration import register_converter
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from .._imputer_implementations import SimpleImputer, MissingIndicator
|
||||||
class SimpleImputer(PhysicalOperator, torch.nn.Module):
    """
    Class implementing SimpleImputer operators in PyTorch.

    Args:
        logical_operator: Operator wrapper; `statistics_`, `strategy` and
            `missing_values` are read from its `raw_operator` attribute
        device: Accepted for signature compatibility; not used by this constructor
    """

    def __init__(self, logical_operator, device):
        super(SimpleImputer, self).__init__(logical_operator)
        sklearn_imputer = logical_operator.raw_operator

        # Convert every statistic. Filtering by `isinstance(stat, float)` would drop
        # entries of other numeric types and shift the positions recorded in the
        # mask away from the real column indices.
        stats = [float(stat) for stat in sklearn_imputer.statistics_]
        b_mask = np.logical_not(np.isnan(stats))
        i_mask = [column for column, keep in enumerate(b_mask) if keep]

        self.transformer = True
        # NOTE: despite its name, `do_mask` is True when NO column selection is
        # applied in `forward` (constant strategy, or no NaN statistic).
        self.do_mask = sklearn_imputer.strategy == "constant" or all(b_mask)
        self.mask = torch.nn.Parameter(torch.LongTensor([] if self.do_mask else i_mask), requires_grad=False)
        self.replace_values = torch.nn.Parameter(torch.tensor([stats], dtype=torch.float32), requires_grad=False)

        self.is_nan = bool(sklearn_imputer.missing_values == "NaN" or np.isnan(sklearn_imputer.missing_values))
        if not self.is_nan:
            # Only needed when the placeholder is a regular number compared with torch.eq.
            self.missing_values = torch.nn.Parameter(
                torch.tensor([sklearn_imputer.missing_values], dtype=torch.float32), requires_grad=False
            )

    def forward(self, x):
        """
        Replace missing entries of `x` with the stored replacement values.

        When the placeholder is NaN, the columns listed in `self.mask` are selected
        afterwards unless `self.do_mask` is set. When the placeholder is a regular
        value, the result is returned without any column selection.
        """
        if not self.is_nan:
            return torch.where(torch.eq(x, self.missing_values), self.replace_values.expand(x.shape), x)

        result = torch.where(torch.isnan(x), self.replace_values.expand(x.shape), x)
        if self.do_mask:
            return result
        return torch.index_select(result, 1, self.mask)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_sklearn_simple_imputer(operator, device, extra_config):
|
def convert_sklearn_simple_imputer(operator, device, extra_config):
|
||||||
|
@ -64,35 +32,6 @@ def convert_sklearn_simple_imputer(operator, device, extra_config):
|
||||||
return SimpleImputer(operator, device)
|
return SimpleImputer(operator, device)
|
||||||
|
|
||||||
|
|
||||||
class MissingIndicator(PhysicalOperator, torch.nn.Module):
    """
    Class implementing MissingIndicator operators in PyTorch.

    Args:
        logical_operator: Operator wrapper; `missing_values`, `features` and
            `features_` are read from its `raw_operator` attribute
        device: Accepted for signature compatibility; not used by this constructor
    """

    def __init__(self, logical_operator, device):
        super(MissingIndicator, self).__init__(logical_operator)
        sklearn_missing_indicator = logical_operator.raw_operator
        raw_missing = sklearn_missing_indicator.missing_values

        self.transformer = True
        # NaN never compares equal to itself, so a list-membership test only finds
        # it by object identity. Detect it explicitly so that any NaN float (not
        # just the `np.nan` singleton) is recognised.
        self.is_nan = (
            raw_missing is None
            or (isinstance(raw_missing, str) and raw_missing == "NaN")
            or (isinstance(raw_missing, (float, np.floating)) and bool(np.isnan(raw_missing)))
        )
        # None / "NaN" cannot be turned into a float tensor; store NaN for them so
        # the parameter always exists.
        self.missing_values = torch.nn.Parameter(
            torch.tensor([float("nan") if self.is_nan else raw_missing], dtype=torch.float32), requires_grad=False
        )
        self.features = sklearn_missing_indicator.features
        self.column_indices = torch.nn.Parameter(torch.LongTensor(sklearn_missing_indicator.features_), requires_grad=False)

    def forward(self, x):
        """
        Return a float tensor holding 1.0 where an entry is missing and 0.0 elsewhere.

        Unless `self.features` is "all", only the columns in `self.column_indices`
        are inspected and returned.
        """
        if self.features != "all":
            x = torch.index_select(x, 1, self.column_indices)
        if self.is_nan:
            return torch.isnan(x).float()
        return torch.eq(x, self.missing_values).float()
|
|
||||||
|
|
||||||
|
|
||||||
def convert_sklearn_missing_indicator(operator, device, extra_config):
|
def convert_sklearn_missing_indicator(operator, device, extra_config):
|
||||||
"""
|
"""
|
||||||
Converter for `sklearn.impute.MissingIndicator`
|
Converter for `sklearn.impute.MissingIndicator`
|
||||||
|
|
|
@ -77,6 +77,7 @@ Binarizer,
|
||||||
Cast,
|
Cast,
|
||||||
Concat,
|
Concat,
|
||||||
Div,
|
Div,
|
||||||
|
Imputer,
|
||||||
LabelEncoder,
|
LabelEncoder,
|
||||||
Less,
|
Less,
|
||||||
LinearClassifier,
|
LinearClassifier,
|
||||||
|
@ -313,6 +314,7 @@ def _build_onnxml_operator_list():
|
||||||
"ArrayFeatureExtractor",
|
"ArrayFeatureExtractor",
|
||||||
"Binarizer",
|
"Binarizer",
|
||||||
"FeatureVectorizer",
|
"FeatureVectorizer",
|
||||||
|
"Imputer",
|
||||||
"LabelEncoder",
|
"LabelEncoder",
|
||||||
"OneHotEncoder",
|
"OneHotEncoder",
|
||||||
"Normalizer",
|
"Normalizer",
|
||||||
|
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
2
setup.py
2
setup.py
|
@ -22,7 +22,7 @@ with open(README) as f:
|
||||||
long_description = long_description[start_pos:]
|
long_description = long_description[start_pos:]
|
||||||
|
|
||||||
install_requires = [
|
install_requires = [
|
||||||
"numpy>=1.15,<=1.19.4",
|
"numpy>=1.15,<=1.20.*",
|
||||||
"onnxconverter-common>=1.6.0,<=1.7.0",
|
"onnxconverter-common>=1.6.0,<=1.7.0",
|
||||||
"scipy<=1.5.4",
|
"scipy<=1.5.4",
|
||||||
"scikit-learn>=0.21.3,<=0.23.2",
|
"scikit-learn>=0.21.3,<=0.23.2",
|
||||||
|
|
|
@ -0,0 +1,101 @@
|
||||||
|
"""
|
||||||
|
Tests onnxml Imputer converter
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from sklearn.impute import SimpleImputer
|
||||||
|
|
||||||
|
from hummingbird.ml._utils import onnx_ml_tools_installed, onnx_runtime_installed, lightgbm_installed
|
||||||
|
from hummingbird.ml import convert
|
||||||
|
|
||||||
|
if onnx_runtime_installed():
|
||||||
|
import onnxruntime as ort
|
||||||
|
if onnx_ml_tools_installed():
|
||||||
|
from onnxmltools import convert_sklearn
|
||||||
|
from onnxmltools.convert.common.data_types import FloatTensorType as FloatTensorType_onnx
|
||||||
|
|
||||||
|
|
||||||
|
class TestONNXImputer(unittest.TestCase):
    """
    Compares the output of a converted ONNX-ML Imputer against ONNX Runtime.
    """

    def _test_imputer_converter(self, model, mode="onnx"):
        """
        Fit `model`, export it to ONNX-ML, convert it with Hummingbird using `mode`
        and return the pair (ONNX Runtime output, converted model output).
        """
        X = np.array([[1, 2], [np.nan, 3], [7, 6]], dtype=np.float32)

        # Silence warnings only for the duration of this helper instead of
        # permanently altering the process-wide warning filters.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            model.fit(X)

            # Create ONNX-ML model
            onnx_ml_model = convert_sklearn(model, initial_types=[("float_input", FloatTensorType_onnx(X.shape))])

            # Get the predictions for the ONNX-ML model
            session = ort.InferenceSession(onnx_ml_model.SerializeToString())
            output_names = [output.name for output in session.get_outputs()]
            inputs = {session.get_inputs()[0].name: X}
            onnx_ml_pred = session.run(output_names, inputs)[0]

            # Create test model by calling converter
            converted_model = convert(onnx_ml_model, mode, X)

            # Get the predictions for the test model
            pred = converted_model.transform(X)

        return onnx_ml_pred, pred

    @unittest.skipIf(
        not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
    )
    def test_onnx_imputer_const(self, rtol=1e-06, atol=1e-06):
        model = SimpleImputer(strategy="constant")
        onnx_ml_pred, onnx_pred = self._test_imputer_converter(model)

        # Check that predicted values match
        np.testing.assert_allclose(onnx_ml_pred, onnx_pred, rtol=rtol, atol=atol)

    @unittest.skipIf(
        not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
    )
    def test_onnx_imputer_const_nan0(self, rtol=1e-06, atol=1e-06):
        model = SimpleImputer(strategy="constant", fill_value=0)
        onnx_ml_pred, onnx_pred = self._test_imputer_converter(model)

        # Check that predicted values match
        np.testing.assert_allclose(onnx_ml_pred, onnx_pred, rtol=rtol, atol=atol)

    @unittest.skipIf(
        not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
    )
    def test_onnx_imputer_mean(self, rtol=1e-06, atol=1e-06):
        model = SimpleImputer(strategy="mean", fill_value="nan")
        onnx_ml_pred, onnx_pred = self._test_imputer_converter(model)

        # Check that predicted values match
        np.testing.assert_allclose(onnx_ml_pred, onnx_pred, rtol=rtol, atol=atol)

    @unittest.skipIf(
        not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
    )
    def test_onnx_imputer_converter_raises_rt(self):
        model = SimpleImputer(strategy="mean", fill_value="nan")
        X = np.array([[1, 2], [np.nan, 3], [7, 6]], dtype=np.float32)

        # Scope the warning suppression to this test only.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            model.fit(X)

            # Create ONNX-ML model
            onnx_ml_model = convert_sklearn(model, initial_types=[("float_input", FloatTensorType_onnx(X.shape))])
            # Blank out the name of the first attribute so the converter cannot find it.
            onnx_ml_model.graph.node[0].attribute[0].name = "".encode()

            self.assertRaises(RuntimeError, convert, onnx_ml_model, "onnx", X)

    @unittest.skipIf(
        not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
    )
    def test_onnx_imputer_torch(self, rtol=1e-06, atol=1e-06):
        model = SimpleImputer(strategy="constant")
        onnx_ml_pred, onnx_pred = self._test_imputer_converter(model, mode="torch")

        # Check that predicted values match
        np.testing.assert_allclose(onnx_ml_pred, onnx_pred, rtol=rtol, atol=atol)
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
|
@ -52,6 +52,12 @@ class TestSklearnTreeConverter(unittest.TestCase):
|
||||||
self.assertIsNotNone(torch_model)
|
self.assertIsNotNone(torch_model)
|
||||||
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
|
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
|
||||||
|
|
||||||
|
from distutils.version import LooseVersion
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if LooseVersion(torch.__version__) >= LooseVersion("1.7.0"):
|
||||||
|
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)
|
||||||
|
|
||||||
# Random forest binary classifier
|
# Random forest binary classifier
|
||||||
def test_random_forest_classifier_binary_converter(self):
|
def test_random_forest_classifier_binary_converter(self):
|
||||||
self._run_tree_classification_converter(RandomForestClassifier, 2, n_estimators=10)
|
self._run_tree_classification_converter(RandomForestClassifier, 2, n_estimators=10)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче