onnx svc (#456)
* add support for modified huber loss in sgd * fix typo * better errors when backend is not available * add test * add fi to convert * fix arrayfeatureextractor add few more onnx operators * addressing comments * bringing in SVC code, not yet onnx ops * adding supported. all onnx should be in place * rebase * bringing in new API changes * passing correct operator * fix to onnx operators * fixes for onnxml svc multiclass * documenting onnx bug * matteo's comments * matteo's comments/simplify test Co-authored-by: Matteo Interlandi <mainterl@microsoft.com>
This commit is contained in:
Родитель
1ce78eb757
Коммит
5a89414897
|
@ -26,6 +26,7 @@ from .onnx import linear as onnx_linear # noqa: E402, F811
|
|||
from .onnx import normalizer as onnx_normalizer # noqa: E402, F811
|
||||
from .onnx import one_hot_encoder as onnx_ohe # noqa: E402, F811
|
||||
from .onnx import scaler as onnx_scaler # noqa: E402, F811
|
||||
from .onnx import sv as onnx_sv # noqa: E402, F811
|
||||
from .onnx import tree_ensemble # noqa: E402
|
||||
from .sklearn import array_feature_extractor as sklearn_afe # noqa: E402
|
||||
from .sklearn import decision_tree # noqa: E402
|
||||
|
|
|
@ -32,6 +32,8 @@ class ArrayFeatureExtractor(PhysicalOperator, torch.nn.Module):
|
|||
self.is_contiguous = is_contiguous
|
||||
|
||||
def forward(self, x):
|
||||
if len(x.shape) == 1:
|
||||
x = x.view(1, -1)
|
||||
if self.is_contiguous:
|
||||
return x[:, self.min : self.max]
|
||||
else:
|
||||
|
|
|
@ -35,4 +35,4 @@ class Concat(PhysicalOperator, torch.nn.Module):
|
|||
)
|
||||
return torch.cat(x, dim=1)
|
||||
else:
|
||||
return torch.stack(x, dim=1)
|
||||
return torch.stack([i.view(-1) for i in x], dim=1)
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) 2020 Supun Nakandala. All Rights Reserved.
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
"""
|
||||
Base class for SV implementation.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
import torch
|
||||
|
||||
from ._physical_operator import PhysicalOperator
|
||||
|
||||
|
||||
class SVC(PhysicalOperator, torch.nn.Module):
|
||||
def __init__(self, logical_operator, kernel, degree, sv, nv, a, b, gamma, coef0, classes, device):
|
||||
super(SVC, self).__init__(logical_operator, classification=True)
|
||||
self.kernel = kernel
|
||||
self.degree = degree
|
||||
self.gamma = gamma
|
||||
self.regression = False
|
||||
sv = sv.toarray() if type(sv) == scipy.sparse.csr.csr_matrix else sv
|
||||
self.sv = torch.nn.Parameter(torch.from_numpy(sv).double(), requires_grad=False)
|
||||
self.sv_t = torch.nn.Parameter(torch.transpose(self.sv, 0, 1), requires_grad=False)
|
||||
self.sv_norm = torch.nn.Parameter(-self.gamma * (self.sv ** 2).sum(1).view(1, -1), requires_grad=False)
|
||||
self.coef0 = coef0
|
||||
self.n_features = sv.shape[1]
|
||||
self.a = a
|
||||
self.b = torch.nn.Parameter(torch.from_numpy(b.reshape(1, -1)).double(), requires_grad=False)
|
||||
self.start = [sum(nv[:i]) for i in range(len(nv))]
|
||||
self.end = [self.start[i] + nv[i] for i in range(len(nv))]
|
||||
self.len_nv = len(nv)
|
||||
true_classes, false_classes = zip(*[(i, j) for i in range(self.len_nv) for j in range(i + 1, self.len_nv)])
|
||||
self.true_classes = torch.nn.Parameter(torch.IntTensor([true_classes]), requires_grad=False)
|
||||
self.false_classes = torch.nn.Parameter(torch.IntTensor([false_classes]), requires_grad=False)
|
||||
self.classes = torch.nn.Parameter(torch.IntTensor(classes), requires_grad=False)
|
||||
self.perform_class_select = False
|
||||
if min(classes) != 0 or max(classes) != len(classes) - 1:
|
||||
self.perform_class_select = True
|
||||
self.n_classes = len(classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.double()
|
||||
|
||||
if self.kernel == "linear":
|
||||
k = torch.mm(x, self.sv_t)
|
||||
elif self.kernel == "rbf":
|
||||
# using quadratic expansion--susseptible to rounding-off errors
|
||||
# http://www.robots.ox.ac.uk/~albanie/notes/Euclidean_distance_trick.pdf
|
||||
x_norm = -self.gamma * (x ** 2).sum(1).view(-1, 1)
|
||||
k = torch.exp(x_norm + self.sv_norm + 2.0 * self.gamma * torch.mm(x, self.sv_t).double())
|
||||
elif self.kernel == "sigmoid":
|
||||
k = torch.sigmoid(self.gamma * torch.mm(x, self.sv_t) + self.coef0)
|
||||
else: # poly kernel
|
||||
k = torch.pow(self.gamma * torch.mm(x, self.sv_t) + self.coef0, self.degree)
|
||||
|
||||
c = [
|
||||
sum(self.a[i, p] * k[:, p : p + 1] for p in range(self.start[j], self.end[j]))
|
||||
+ sum(self.a[j - 1, p] * k[:, p : p + 1] for p in range(self.start[i], self.end[i]))
|
||||
for i in range(self.len_nv)
|
||||
for j in range(i + 1, self.len_nv)
|
||||
]
|
||||
c = torch.cat(c, dim=1) + self.b
|
||||
if self.n_classes == 2:
|
||||
class_ids = torch.gt(c, 0.0).int().flatten()
|
||||
else:
|
||||
votes = torch.where(c > 0, self.true_classes, self.false_classes)
|
||||
# TODO mode is still not implemented for GPU backend.
|
||||
votes = votes.data.cpu()
|
||||
class_ids, _ = torch.mode(votes, dim=1)
|
||||
# No class probabilities in SVC.
|
||||
if self.perform_class_select:
|
||||
temp = torch.index_select(self.classes, 0, class_ids.long())
|
||||
return temp, temp
|
||||
else:
|
||||
return class_ids, class_ids
|
|
@ -64,15 +64,19 @@ class Sum(PhysicalOperator, torch.nn.Module):
|
|||
def __init__(self, logical_operator):
|
||||
super(Sum, self).__init__(logical_operator)
|
||||
|
||||
def forward(self, x):
|
||||
return torch.sum(x).view(1)
|
||||
def forward(self, *x):
|
||||
if len(x) > 1:
|
||||
x = torch.cat(x, dim=1)
|
||||
return torch.sum(*x)
|
||||
|
||||
|
||||
class Add(PhysicalOperator, torch.nn.Module):
|
||||
def __init__(self, logical_operator, val):
|
||||
super(Add, self).__init__(logical_operator)
|
||||
|
||||
self.val = torch.nn.Parameter(torch.FloatTensor(val), requires_grad=False)
|
||||
if val is not None:
|
||||
assert len(self.inputs) == 1, "Unexpected input length for Add val"
|
||||
self.val = torch.nn.Parameter(torch.FloatTensor(val), requires_grad=False)
|
||||
|
||||
def forward(self, *x):
|
||||
if len(x) == 1:
|
||||
|
@ -84,7 +88,7 @@ class Less(PhysicalOperator, torch.nn.Module):
|
|||
def __init__(self, logical_operator, val):
|
||||
super(Less, self).__init__(logical_operator)
|
||||
|
||||
self.val = val
|
||||
self.val = torch.nn.Parameter(torch.FloatTensor(val), requires_grad=False)
|
||||
|
||||
def forward(self, x):
|
||||
return torch.lt(x, self.val)
|
||||
|
@ -94,8 +98,8 @@ class Neg(PhysicalOperator, torch.nn.Module):
|
|||
def __init__(self, logical_operator):
|
||||
super(Neg, self).__init__(logical_operator)
|
||||
|
||||
def forward(self, x):
|
||||
return torch.neg(x).view(-1)
|
||||
def forward(self, *x):
|
||||
return torch.neg(*x)
|
||||
|
||||
|
||||
class Abs(PhysicalOperator, torch.nn.Module):
|
||||
|
@ -110,10 +114,14 @@ class Mul(PhysicalOperator, torch.nn.Module):
|
|||
def __init__(self, logical_operator, val):
|
||||
super(Mul, self).__init__(logical_operator)
|
||||
|
||||
self.val = val
|
||||
if val is not None:
|
||||
assert len(self.inputs) == 1, "Unexpected input length for Mul val"
|
||||
self.val = torch.nn.Parameter(torch.FloatTensor(val), requires_grad=False)
|
||||
|
||||
def forward(self, x):
|
||||
return torch.mul(x, self.val)
|
||||
def forward(self, *x):
|
||||
if len(x) == 1:
|
||||
return torch.mul(*x, self.val)
|
||||
return torch.mul(*x)
|
||||
|
||||
|
||||
class MatMul(PhysicalOperator, torch.nn.Module):
|
||||
|
@ -256,7 +264,9 @@ def convert_onnx_add(operator, device=None, extra_config={}):
|
|||
assert operator is not None
|
||||
|
||||
initializers = extra_config[constants.ONNX_INITIALIZERS]
|
||||
val = list(initializers[operator.raw_operator.origin.input[1]].float_data)
|
||||
val = None
|
||||
if operator.raw_operator.origin.input[1] in initializers:
|
||||
val = list(initializers[operator.raw_operator.origin.input[1]].float_data)
|
||||
|
||||
# Generate the model.
|
||||
return Add(operator, val)
|
||||
|
@ -313,7 +323,9 @@ def convert_onnx_mul(operator, device=None, extra_config={}):
|
|||
assert operator is not None
|
||||
|
||||
initializers = extra_config[constants.ONNX_INITIALIZERS]
|
||||
val = list(initializers[operator.raw_operator.origin.input[1]].float_data)
|
||||
val = None
|
||||
if operator.raw_operator.origin.input[1] in initializers:
|
||||
val = list(initializers[operator.raw_operator.origin.input[1]].float_data)
|
||||
|
||||
# Generate the model.
|
||||
return Mul(operator, val)
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
"""
|
||||
Converters for ONNX-ML SV models.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from onnxconverter_common.registration import register_converter
|
||||
|
||||
from .._sv_implementations import SVC
|
||||
|
||||
|
||||
def convert_onnx_svm_classifier_model(operator, device, extra_config):
|
||||
"""
|
||||
Converter for `ai.onnx.ml.SVMClassifier`
|
||||
|
||||
Args:
|
||||
operator: An operator wrapping a `ai.onnx.ml.SVMClassifier` model
|
||||
device: String defining the type of device the converted operator should be run on
|
||||
extra_config: Extra configuration used to select the best conversion strategy
|
||||
|
||||
Returns:
|
||||
A PyTorch model
|
||||
"""
|
||||
|
||||
# These are passed as params to SVC()
|
||||
kernel = degree = sv = nv = a = b = gamma = coef0 = classes = None
|
||||
|
||||
# These are stored for reshaping after parsing is done
|
||||
sv_vals = coeffis = None
|
||||
|
||||
for attr in operator.raw_operator.origin.attribute:
|
||||
|
||||
if attr.name == "kernel_type":
|
||||
# ex: Convert b'RBF' to 'rbf' for consistency
|
||||
kernel = attr.s.lower().decode("UTF-8")
|
||||
if kernel not in ["linear", "poly", "rbf"]: # from svc.py ln 58
|
||||
raise RuntimeError("Unsupported kernel for SVC: {}".format(kernel))
|
||||
|
||||
elif attr.name == "coefficients":
|
||||
coeffis = np.array(attr.floats)
|
||||
|
||||
elif attr.name == "vectors_per_class":
|
||||
nv = np.array(attr.ints).astype("int32")
|
||||
|
||||
elif attr.name == "support_vectors":
|
||||
sv_vals = np.array(attr.floats)
|
||||
|
||||
elif attr.name == "rho":
|
||||
b = np.array(attr.floats)
|
||||
|
||||
elif attr.name == "kernel_params":
|
||||
# See
|
||||
# https://github.com/onnx/sklearn-onnx/blob/master/skl2onnx/operator_converters/support_vector_machines.py
|
||||
# for details on [op._gamma, op.coef0, op.degree]
|
||||
kp_arr = np.array(attr.floats)
|
||||
gamma = kp_arr[0]
|
||||
coef0 = kp_arr[1]
|
||||
degree = int(kp_arr[2])
|
||||
|
||||
elif attr.name == "classlabels_ints":
|
||||
classes = np.array(attr.ints)
|
||||
|
||||
if any(v is None for v in [sv_vals, coeffis]):
|
||||
raise RuntimeError("Error parsing SVC arrays, found unexpected None")
|
||||
|
||||
# Now that we have parsed the degree and lengths, reshape 'a' and 'sv'
|
||||
# For 'a', these are in 'dual' shape, so resize into 2:
|
||||
# https://github.com/onnx/sklearn-onnx/blob/master/skl2onnx/operator_converters/support_vector_machines.py#L41
|
||||
#
|
||||
# Except for when they're not...
|
||||
# https://stackoverflow.com/questions/22816646/the-dimension-of-dual-coef-in-sklearn-svc
|
||||
if len(classes) > 2:
|
||||
a = coeffis.reshape(2, len(coeffis) // 2)
|
||||
else: # if not in "dual" form with classes > 3 (binary), 'a' and 'b' are the inverse. Don't ask why.
|
||||
a = np.negative([coeffis])
|
||||
b = np.negative(b)
|
||||
|
||||
sv = sv_vals.reshape(len(a[0]), len(sv_vals) // len(a[0]))
|
||||
|
||||
if any(v is None for v in [kernel, degree, sv, nv, a, b, gamma, coef0, classes]):
|
||||
raise RuntimeError(
|
||||
"Error parsing SVC, found unexpected None. kernel{} degree{} sv{} nv{} a{} b{} gamma{} coef0{} classes{}".format(
|
||||
kernel, degree, sv, nv, a, b, gamma, coef0, classes
|
||||
)
|
||||
)
|
||||
|
||||
return SVC(operator, kernel, degree, sv, nv, a, b, gamma, coef0, classes, device)
|
||||
|
||||
|
||||
register_converter("ONNXMLSVMClassifier", convert_onnx_svm_classifier_model)
|
|
@ -8,76 +8,8 @@
|
|||
Converters for scikit-learn SV models: SVC, NuSVC. (LinearSVC is covered by linear_classifier.py).
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
import torch
|
||||
from onnxconverter_common.registration import register_converter
|
||||
|
||||
from .._physical_operator import PhysicalOperator
|
||||
|
||||
|
||||
class SVC(PhysicalOperator, torch.nn.Module):
|
||||
def __init__(self, operator, kernel, degree, sv, nv, a, b, gamma, coef0, classes, device):
|
||||
super(SVC, self).__init__(operator, classification=True)
|
||||
self.kernel = kernel
|
||||
self.degree = degree
|
||||
self.gamma = gamma
|
||||
self.regression = False
|
||||
sv = sv.toarray() if type(sv) == scipy.sparse.csr.csr_matrix else sv
|
||||
self.sv = torch.nn.Parameter(torch.from_numpy(sv).double(), requires_grad=False)
|
||||
self.sv_t = torch.nn.Parameter(torch.transpose(self.sv, 0, 1), requires_grad=False)
|
||||
self.sv_norm = torch.nn.Parameter(-self.gamma * (self.sv ** 2).sum(1).view(1, -1), requires_grad=False)
|
||||
self.coef0 = coef0
|
||||
self.n_features = sv.shape[1]
|
||||
self.a = a
|
||||
self.b = torch.nn.Parameter(torch.from_numpy(b.reshape(1, -1)).double(), requires_grad=False)
|
||||
self.start = [sum(nv[:i]) for i in range(len(nv))]
|
||||
self.end = [self.start[i] + nv[i] for i in range(len(nv))]
|
||||
self.len_nv = len(nv)
|
||||
true_classes, false_classes = zip(*[(i, j) for i in range(self.len_nv) for j in range(i + 1, self.len_nv)])
|
||||
self.true_classes = torch.nn.Parameter(torch.IntTensor([true_classes]), requires_grad=False)
|
||||
self.false_classes = torch.nn.Parameter(torch.IntTensor([false_classes]), requires_grad=False)
|
||||
self.classes = torch.nn.Parameter(torch.IntTensor(classes), requires_grad=False)
|
||||
self.perform_class_select = False
|
||||
if min(classes) != 0 or max(classes) != len(classes) - 1:
|
||||
self.perform_class_select = True
|
||||
self.n_classes = len(classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.double()
|
||||
|
||||
if self.kernel == "linear":
|
||||
k = torch.mm(x, self.sv_t)
|
||||
elif self.kernel == "rbf":
|
||||
# using quadratic expansion--susseptible to rounding-off errors
|
||||
# http://www.robots.ox.ac.uk/~albanie/notes/Euclidean_distance_trick.pdf
|
||||
x_norm = -self.gamma * (x ** 2).sum(1).view(-1, 1)
|
||||
k = torch.exp(x_norm + self.sv_norm + 2.0 * self.gamma * torch.mm(x, self.sv_t).double())
|
||||
elif self.kernel == "sigmoid":
|
||||
k = torch.sigmoid(self.gamma * torch.mm(x, self.sv_t) + self.coef0)
|
||||
else: # poly kernel
|
||||
k = torch.pow(self.gamma * torch.mm(x, self.sv_t) + self.coef0, self.degree)
|
||||
|
||||
c = [
|
||||
sum(self.a[i, p] * k[:, p : p + 1] for p in range(self.start[j], self.end[j]))
|
||||
+ sum(self.a[j - 1, p] * k[:, p : p + 1] for p in range(self.start[i], self.end[i]))
|
||||
for i in range(self.len_nv)
|
||||
for j in range(i + 1, self.len_nv)
|
||||
]
|
||||
c = torch.cat(c, dim=1) + self.b
|
||||
if self.n_classes == 2:
|
||||
class_ids = torch.gt(c, 0.0).int().flatten()
|
||||
else:
|
||||
votes = torch.where(c > 0, self.true_classes, self.false_classes)
|
||||
# TODO mode is still not implemented for GPU backend.
|
||||
votes = votes.data.cpu()
|
||||
class_ids, _ = torch.mode(votes, dim=1)
|
||||
# No class probabilities in SVC.
|
||||
if self.perform_class_select:
|
||||
temp = torch.index_select(self.classes, 0, class_ids.long())
|
||||
return temp, temp
|
||||
else:
|
||||
return class_ids, class_ids
|
||||
from .._sv_implementations import SVC
|
||||
|
||||
|
||||
def convert_sklearn_svc_model(operator, device, extra_config):
|
||||
|
@ -92,8 +24,6 @@ def convert_sklearn_svc_model(operator, device, extra_config):
|
|||
Returns:
|
||||
A PyTorch model
|
||||
"""
|
||||
assert operator is not None, "Cannot convert None operator"
|
||||
|
||||
if operator.raw_operator.kernel in ["linear", "poly", "rbf", "sigmoid"]:
|
||||
# https://stackoverflow.com/questions/20113206/scikit-learn-svc-decision-function-and-predict
|
||||
kernel = operator.raw_operator.kernel
|
||||
|
|
|
@ -70,20 +70,27 @@ XGBRanker,
|
|||
XGBRegressor,
|
||||
|
||||
**Supported Operators (ONNX-ML)**
|
||||
"ArrayFeatureExtractor",
|
||||
"Binarizer"
|
||||
"Cast",
|
||||
"Concat",
|
||||
"FeatureVectorizer"
|
||||
"LabelEncoder",
|
||||
"LinearClassifier",
|
||||
"LinearRegressor",
|
||||
"OneHotEncoder",
|
||||
"Normalizer",
|
||||
"Reshape",
|
||||
"Scaler",
|
||||
"TreeEnsembleClassifier",
|
||||
"TreeEnsembleRegressor",
|
||||
Abs,
|
||||
Add,
|
||||
ArrayFeatureExtractor,
|
||||
Binarizer,
|
||||
Cast,
|
||||
Concat,
|
||||
Div,
|
||||
LabelEncoder,
|
||||
Less,
|
||||
LinearClassifier,
|
||||
LinearRegressor,
|
||||
Mul,
|
||||
Neg,
|
||||
Normalizer,
|
||||
OneHotEncoder,
|
||||
Reshape,
|
||||
Sum,
|
||||
Scaler,
|
||||
SVMClassifier,
|
||||
TreeEnsembleClassifier,
|
||||
TreeEnsembleRegressor,
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
|
@ -301,6 +308,7 @@ def _build_onnxml_operator_list():
|
|||
"Mul",
|
||||
"Neg",
|
||||
"Reshape",
|
||||
"Sum",
|
||||
# Preprocessing
|
||||
"ArrayFeatureExtractor",
|
||||
"Binarizer",
|
||||
|
@ -309,6 +317,7 @@ def _build_onnxml_operator_list():
|
|||
"OneHotEncoder",
|
||||
"Normalizer",
|
||||
"Scaler",
|
||||
"SVMClassifier",
|
||||
# Tree-based models
|
||||
"TreeEnsembleClassifier",
|
||||
"TreeEnsembleRegressor",
|
||||
|
|
|
@ -7,7 +7,6 @@ import warnings
|
|||
import numpy as np
|
||||
import torch
|
||||
from sklearn.linear_model import LinearRegression, LogisticRegression, SGDClassifier, LogisticRegressionCV
|
||||
from sklearn.svm import LinearSVC, SVC, NuSVC
|
||||
|
||||
from hummingbird.ml._utils import onnx_ml_tools_installed, onnx_runtime_installed, lightgbm_installed
|
||||
from hummingbird.ml import convert
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
"""
|
||||
Tests onnxml SV converters
|
||||
"""
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from sklearn.svm import LinearSVC, SVC, NuSVC
|
||||
|
||||
from hummingbird.ml._utils import onnx_ml_tools_installed, onnx_runtime_installed, lightgbm_installed
|
||||
from hummingbird.ml import convert
|
||||
|
||||
if onnx_runtime_installed():
|
||||
import onnxruntime as ort
|
||||
if onnx_ml_tools_installed():
|
||||
from onnxmltools import convert_sklearn
|
||||
from onnxmltools.convert.common.data_types import FloatTensorType as FloatTensorType_onnx
|
||||
|
||||
|
||||
class TestONNXSVC(unittest.TestCase):
|
||||
def _test_sv(self, classes, mode="torch"):
|
||||
"""
|
||||
This helper function tests conversion of `ai.onnx.ml.SVMClassifier`
|
||||
which is created from a scikit-learn SVC.
|
||||
|
||||
This then calls either "_to_onnx" or "_to_torch"
|
||||
"""
|
||||
n_features = 20
|
||||
n_total = 100
|
||||
np.random.seed(0)
|
||||
warnings.filterwarnings("ignore")
|
||||
X = np.random.rand(n_total, n_features)
|
||||
X = np.array(X, dtype=np.float32)
|
||||
y = np.random.randint(classes, size=n_total)
|
||||
|
||||
# Create SKL model for testing
|
||||
model = SVC()
|
||||
model.fit(X, y)
|
||||
|
||||
# Create ONNX-ML model
|
||||
onnx_ml_model = convert_sklearn(model, initial_types=[("float_input", FloatTensorType_onnx(X.shape))])
|
||||
|
||||
# Get the predictions for the ONNX-ML model
|
||||
session = ort.InferenceSession(onnx_ml_model.SerializeToString())
|
||||
output_names = [session.get_outputs()[i].name for i in range(len(session.get_outputs()))]
|
||||
onnx_ml_pred = [[] for i in range(len(output_names))]
|
||||
inputs = {session.get_inputs()[0].name: X}
|
||||
pred = session.run(output_names, inputs)
|
||||
for i in range(len(output_names)):
|
||||
if "label" in output_names[i]:
|
||||
onnx_ml_pred[1] = pred[i]
|
||||
else:
|
||||
onnx_ml_pred[0] = pred[i]
|
||||
|
||||
model = convert(onnx_ml_model, mode, X)
|
||||
|
||||
pred = model.predict(X)
|
||||
|
||||
return onnx_ml_pred, pred
|
||||
|
||||
@unittest.skipIf(
|
||||
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
|
||||
)
|
||||
# test ai.onnx.ml.SVMClassifier with 2 classes for onnxml-> pytorch
|
||||
def test_logistic_regression_onnxml_binary_torch(self, rtol=1e-06, atol=1e-06):
|
||||
onnx_ml_pred, pred = self._test_sv(2)
|
||||
|
||||
# Check that predicted values match
|
||||
np.testing.assert_allclose(onnx_ml_pred[1], pred, rtol=rtol, atol=atol)
|
||||
|
||||
@unittest.skipIf(
|
||||
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
|
||||
)
|
||||
# test ai.onnx.ml.SVMClassifier with 3 classes for onnxml-> pytorch
|
||||
def test_logistic_regression_onnxml_multi_torch(self, rtol=1e-06, atol=1e-06):
|
||||
onnx_ml_pred, pred = self._test_sv(3)
|
||||
|
||||
# Check that predicted values match
|
||||
np.testing.assert_allclose(onnx_ml_pred[1], pred, rtol=rtol, atol=atol)
|
||||
|
||||
# TODO: There is a bug with ORT:
|
||||
# onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented:
|
||||
# [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for the node Gemm_8:Gemm(11)
|
||||
# @unittest.skipIf(
|
||||
# not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
|
||||
# )
|
||||
# # test ai.onnx.ml.SVMClassifier with 2 classes
|
||||
# def test_logistic_regression_onnxml_binary_onnx(self, rtol=1e-06, atol=1e-06):
|
||||
# onnx_ml_pred, onnx_pred = self._test_sv(2, mode="onnx")
|
||||
|
||||
# # Check that predicted values match
|
||||
# np.testing.assert_allclose(onnx_ml_pred[1], onnx_pred[1], rtol=rtol, atol=atol) # labels
|
||||
# np.testing.assert_allclose(
|
||||
# list(map(lambda x: list(x.values()), onnx_ml_pred[0])), onnx_pred[0], rtol=rtol, atol=atol
|
||||
# )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Загрузка…
Ссылка в новой задаче