onnxml linearregressor (#194)
* placeholders
* todo: clarify why the onnxml logistic regression name doesn't line up with SKL
* tests for regressor
* adding test for parse failure
* addressing Matteo's feedback
* adding back comment
Parent: e50eb9cf03
Commit: cfb08f8f51
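For context, the converter added in this commit is exercised end to end by the new tests below: a scikit-learn LinearRegression is exported to an `ai.onnx.ml.LinearRegressor` node with onnxmltools, translated by Hummingbird's `convert` with the "onnx" backend, and both models are run through onnxruntime and compared. A minimal sketch of that flow (the import paths for `convert_sklearn` and `convert` are assumed to match the ones used by these tests; data and tolerances are illustrative):

import numpy as np
import onnxruntime as ort
from sklearn.linear_model import LinearRegression
from onnxmltools import convert_sklearn
from onnxmltools.convert.common.data_types import FloatTensorType
from hummingbird.ml import convert

# Train a small scikit-learn regressor on random data.
X = np.random.rand(100, 20).astype(np.float32)
y = np.random.rand(100).astype(np.float32)
skl_model = LinearRegression().fit(X, y)

# Export to ONNX-ML; the regressor becomes an ai.onnx.ml.LinearRegressor node.
onnx_ml_model = convert_sklearn(skl_model, initial_types=[("float_input", FloatTensorType(X.shape))])

# Translate the ONNX-ML model with Hummingbird's "onnx" backend (this commit's converter).
onnx_model = convert(onnx_ml_model, "onnx", X)

# Run both models through onnxruntime and compare the predictions.
ml_session = ort.InferenceSession(onnx_ml_model.SerializeToString())
hb_session = ort.InferenceSession(onnx_model.SerializeToString())
inputs = {ml_session.get_inputs()[0].name: X}
ml_pred = ml_session.run(None, inputs)
hb_pred = hb_session.run(None, inputs)
np.testing.assert_allclose(ml_pred[0], hb_pred[0], rtol=1e-6, atol=1e-6)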
@@ -31,7 +31,6 @@ def convert_onnx_linear_model(operator, device=None, extra_config={}):
     operator = operator.raw_operator
     coefficients = intercepts = classes = multi_class = None
-    is_linear_regression = False

     for attr in operator.origin.attribute:
         if attr.name == "coefficients":
@@ -45,12 +44,9 @@ def convert_onnx_linear_model(operator, device=None, extra_config={}):
             multi_class = "multinomial"

     if any(v is None for v in [coefficients, intercepts, classes]):
-        print("coefficients{}, intercepts{}, classes{}".format(coefficients, intercepts, classes))
         raise RuntimeError("Error parsing LinearClassifier, found unexpected None")
     if multi_class is None:  # if 'multi_class' attr was not present
         multi_class = "none" if len(classes) < 3 else "ovr"
-    if operator.op_type == "LinearRegressor":
-        is_linear_regression = True

     # Now reshape the coefficients/intercepts
     if len(classes) == 2:
@@ -70,9 +66,35 @@ def convert_onnx_linear_model(operator, device=None, extra_config={}):
         coefficients = np.array(list(zip(*tmp)))
     else:
         raise RuntimeError("Error parsing LinearClassifier, length of classes {} unexpected:{}".format(len(classes), classes))
-    return LinearModel(
-        coefficients, intercepts, device, classes=classes, multi_class=multi_class, is_linear_regression=is_linear_regression
-    )
+    return LinearModel(coefficients, intercepts, device, classes=classes, multi_class=multi_class, is_linear_regression=False)
+
+
+def convert_onnx_linear_regression_model(operator, device, extra_config):
+    """
+    Converter for `ai.onnx.ml.LinearRegressor`
+
+    Args:
+        operator: An operator wrapping an `ai.onnx.ml.LinearRegressor` model
+        device: String defining the type of device the converted operator should be run on
+        extra_config: Extra configuration used to select the best conversion strategy
+
+    Returns:
+        A PyTorch model
+    """
+    assert operator is not None
+
+    operator = operator.raw_operator
+    coefficients = intercepts = None
+    for attr in operator.origin.attribute:
+        if attr.name == "coefficients":
+            coefficients = np.array([[np.array(val).astype("float32")] for val in attr.floats]).astype("float32")
+        elif attr.name == "intercepts":
+            intercepts = np.array(attr.floats).astype("float32")
+
+    if any(v is None for v in [coefficients, intercepts]):
+        raise RuntimeError("Error parsing LinearRegression, found unexpected None")
+
+    return LinearModel(coefficients, intercepts, device, is_linear_regression=True)
+
+
 register_converter("ONNXMLLinearClassifier", convert_onnx_linear_model)
+register_converter("ONNXMLLinearRegressor", convert_onnx_linear_regression_model)
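The new regressor converter above turns the node's flat `coefficients` attribute into a column vector before handing it to `LinearModel`. A standalone numpy sketch of that reshaping, with made-up weights standing in for `attr.floats`:

import numpy as np

# Made-up weights standing in for attr.floats on an ai.onnx.ml.LinearRegressor node
# (one weight per input feature).
attr_floats = [0.25, -1.5, 3.0]

# Same expression as in convert_onnx_linear_regression_model: each weight becomes
# its own single-element row, producing an (n_features, 1) column vector of float32.
coefficients = np.array([[np.array(v).astype("float32")] for v in attr_floats]).astype("float32")
intercepts = np.array([0.5], dtype="float32")

print(coefficients.shape)  # (3, 1)
print(coefficients.dtype)  # float32

# The usual linear-regression prediction for one input row is then x @ coefficients + intercepts.
x = np.array([[1.0, 2.0, 3.0]], dtype="float32")
print(x @ coefficients + intercepts)  # [[6.75]] = 0.25*1 + (-1.5)*2 + 3.0*3 + 0.5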
@@ -154,6 +154,7 @@ def _build_onnxml_operator_list():
     return [
         # Linear-based models
         "LinearClassifier",
+        "LinearRegressor",
         # ONNX operators.
         "Cast",
         # Preprocessing
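Adding "LinearRegressor" to this list is what allows graphs containing that op to be parsed and routed to the converter registered above. A rough, hypothetical helper (not Hummingbird's actual gating code) for spotting unsupported node types in an ONNX-ML graph before conversion:

def unsupported_onnxml_ops(onnx_ml_model, supported=("LinearClassifier", "LinearRegressor", "Cast")):
    """Return the op_types in an ONNX(-ML) graph that are not in the supported list."""
    return sorted({node.op_type for node in onnx_ml_model.graph.node if node.op_type not in supported})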
@@ -1,5 +1,5 @@
 """
-Tests onnxml Normalizer converter
+Tests onnxml Linear converters
 """
 import unittest
 import warnings
@@ -19,8 +19,14 @@ if onnx_ml_tools_installed():
     from onnxmltools.convert.common.data_types import FloatTensorType as FloatTensorType_onnx


-class TestSklearnNormalizer(unittest.TestCase):
-    def _test_regressor(self, classes):
+class TestONNXLinear(unittest.TestCase):
+    def _test_linear(self, classes):
+        """
+        This helper function tests conversion of `ai.onnx.ml.LinearClassifier`
+        which is created from a scikit-learn LogisticRegression.
+
+        This tests `convert_onnx_linear_model` in `hummingbird.ml.operator_converters.onnxml_linear`
+        """
         n_features = 20
         n_total = 100
         np.random.seed(0)
@@ -66,8 +72,9 @@ class TestSklearnNormalizer(unittest.TestCase):
     @unittest.skipIf(
         not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
     )
+    # test ai.onnx.ml.LinearClassifier with 2 classes
     def test_logistic_regression_onnxml_binary(self, rtol=1e-06, atol=1e-06):
-        onnx_ml_pred, onnx_pred = self._test_regressor(2)
+        onnx_ml_pred, onnx_pred = self._test_linear(2)

         # Check that predicted values match
         np.testing.assert_allclose(onnx_ml_pred[1], onnx_pred[1], rtol=rtol, atol=atol) # labels
@@ -78,8 +85,9 @@ class TestSklearnNormalizer(unittest.TestCase):
     @unittest.skipIf(
         not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
     )
+    # test ai.onnx.ml.LinearClassifier with 3 classes
     def test_logistic_regression_onnxml_multi(self, rtol=1e-06, atol=1e-06):
-        onnx_ml_pred, onnx_pred = self._test_regressor(3)
+        onnx_ml_pred, onnx_pred = self._test_linear(3)

         # Check that predicted values match
         np.testing.assert_allclose(onnx_ml_pred[1], onnx_pred[1], rtol=rtol, atol=atol) # labels
@@ -87,6 +95,86 @@ class TestSklearnNormalizer(unittest.TestCase):
             list(map(lambda x: list(x.values()), onnx_ml_pred[0])), onnx_pred[0], rtol=rtol, atol=atol
         ) # probs

+    def _test_regressor(self, values):
+        """
+        This helper function tests conversion of `ai.onnx.ml.LinearRegressor`
+        which is created from a scikit-learn LinearRegression.
+
+        This tests `convert_onnx_linear_regression_model` in `hummingbird.ml.operator_converters.onnxml_linear`
+        """
+        n_features = 20
+        n_total = 100
+        np.random.seed(0)
+        warnings.filterwarnings("ignore")
+        X = np.random.rand(n_total, n_features)
+        X = np.array(X, dtype=np.float32)
+        y = np.random.randint(values, size=n_total)
+
+        # Create SKL model for testing
+        model = LinearRegression()
+        model.fit(X, y)
+
+        # Create ONNX-ML model
+        onnx_ml_model = convert_sklearn(model, initial_types=[("float_input", FloatTensorType_onnx(X.shape))])
+
+        # Create ONNX model by calling converter
+        onnx_model = convert(onnx_ml_model, "onnx", X)
+
+        # Get the predictions for the ONNX-ML model
+        session = ort.InferenceSession(onnx_ml_model.SerializeToString())
+        output_names = [session.get_outputs()[i].name for i in range(len(session.get_outputs()))]
+        onnx_ml_pred = [[] for i in range(len(output_names))]
+        inputs = {session.get_inputs()[0].name: X}
+        onnx_ml_pred = session.run(output_names, inputs)
+
+        # Get the predictions for the ONNX model
+        session = ort.InferenceSession(onnx_model.SerializeToString())
+        onnx_pred = [[] for i in range(len(output_names))]
+        onnx_pred = session.run(output_names, inputs)
+
+        return onnx_ml_pred, onnx_pred
+
+    @unittest.skipIf(
+        not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
+    )
+    # test ai.onnx.ml.LinearRegressor with 2 values
+    def test_linear_regression_onnxml_small(self, rtol=1e-06, atol=1e-06):
+        onnx_ml_pred, onnx_pred = self._test_regressor(2)
+
+        # Check that predicted values match
+        np.testing.assert_allclose(onnx_ml_pred, onnx_pred, rtol=rtol, atol=atol)
+
+    @unittest.skipIf(
+        not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
+    )
+    # test ai.onnx.ml.LinearRegressor with 100 values
+    def test_linear_regression_onnxml_large(self, rtol=1e-06, atol=1e-06):
+        onnx_ml_pred, onnx_pred = self._test_regressor(100)
+
+        # Check that predicted values match
+        np.testing.assert_allclose(onnx_ml_pred, onnx_pred, rtol=rtol, atol=atol)
+
+    @unittest.skipIf(
+        not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
+    )
+    # test for malformed model/problem with parsing
+    def test_onnx_linear_converter_raises_rt(self):
+        n_features = 20
+        n_total = 100
+        np.random.seed(0)
+        warnings.filterwarnings("ignore")
+        X = np.random.rand(n_total, n_features)
+        X = np.array(X, dtype=np.float32)
+        y = np.random.randint(3, size=n_total)
+        model = LinearRegression()
+        model.fit(X, y)
+
+        # generate test input
+        onnx_ml_model = convert_sklearn(model, initial_types=[("float_input", FloatTensorType_onnx(X.shape))])
+        onnx_ml_model.graph.node[0].attribute[0].name = "".encode()
+
+        self.assertRaises(RuntimeError, convert, onnx_ml_model, "onnx", X)
+

 if __name__ == "__main__":
     unittest.main()
@@ -18,7 +18,7 @@ if onnx_ml_tools_installed():
     from onnxmltools.convert.common.data_types import FloatTensorType as FloatTensorType_onnx


-class TestSklearnNormalizer(unittest.TestCase):
+class TestONNXNormalizer(unittest.TestCase):
     def _test_normalizer_converter(self, norm):
         warnings.filterwarnings("ignore")
         X = np.array([[1, 2, 3], [4, 3, 0], [0, 1, 4], [0, 5, 6]], dtype=np.float32)