Родитель
55f63fe3a3
Коммит
7b717b98f3
|
@ -30,13 +30,13 @@ from .onnx import scaler as onnx_scaler # noqa: E402, F811
|
|||
from .onnx import sv as onnx_sv # noqa: E402, F811
|
||||
from .onnx import tree_ensemble # noqa: E402
|
||||
from .sklearn import array_feature_extractor as sklearn_afe # noqa: E402
|
||||
from .sklearn import cluster # noqa: E402
|
||||
from .sklearn import decision_tree # noqa: E402
|
||||
from .sklearn import decomposition # noqa: E402
|
||||
from .sklearn import discretizer as sklearn_discretizer # noqa: E402
|
||||
from .sklearn import gbdt # noqa: E402
|
||||
from .sklearn import iforest # noqa: E402
|
||||
from .sklearn import imputer # noqa: E402
|
||||
from .sklearn import kmeans # noqa: E402
|
||||
from .sklearn import kneighbors # noqa: E402
|
||||
from .sklearn import label_encoder # noqa: E402
|
||||
from .sklearn import linear as sklearn_linear # noqa: E402
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
# --------------------------------------------------------------------------
|
||||
|
||||
"""
|
||||
Converters for scikit-learn KMeans.
|
||||
Converters for scikit-learn KMeans and MeanShift models.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
@ -40,3 +40,4 @@ def convert_sklearn_kmeans_model(operator, device, extra_config):
|
|||
|
||||
|
||||
register_converter("SklearnKMeans", convert_sklearn_kmeans_model)
|
||||
register_converter("SklearnMeanShift", convert_sklearn_kmeans_model)
|
|
@ -40,6 +40,7 @@ LogisticRegression,
|
|||
LogisticRegressionCV,
|
||||
RidgeCV,
|
||||
MaxAbsScaler,
|
||||
MeanShift,
|
||||
MinMaxScaler,
|
||||
MissingIndicator,
|
||||
MLPClassifier,
|
||||
|
@ -156,7 +157,7 @@ def _build_sklearn_operator_list():
|
|||
from sklearn.neighbors import KNeighborsRegressor
|
||||
|
||||
# Clustering models
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn.cluster import KMeans, MeanShift
|
||||
|
||||
# Preprocessing
|
||||
from sklearn.preprocessing import (
|
||||
|
@ -205,6 +206,7 @@ def _build_sklearn_operator_list():
|
|||
RidgeCV,
|
||||
# Clustering
|
||||
KMeans,
|
||||
MeanShift,
|
||||
# Other models
|
||||
BernoulliNB,
|
||||
GaussianNB,
|
||||
|
|
|
@ -12,8 +12,13 @@ import hummingbird.ml
|
|||
from hummingbird.ml import constants
|
||||
from hummingbird.ml._utils import tvm_installed
|
||||
|
||||
try:
|
||||
from sklearn.cluster import MeanShift
|
||||
except Exception:
|
||||
MeanShift = None
|
||||
|
||||
class TestSklearnKMeans(unittest.TestCase):
|
||||
|
||||
class TestSklearnClustering(unittest.TestCase):
|
||||
# KMeans test function to be parameterized
|
||||
def _test_kmeans(self, n_clusters, algorithm="full", backend="torch", extra_config={}):
|
||||
model = KMeans(n_clusters=n_clusters, algorithm=algorithm, random_state=0)
|
||||
|
@ -50,6 +55,31 @@ class TestSklearnKMeans(unittest.TestCase):
|
|||
def test_kmeans_5_elkan(self):
|
||||
self._test_kmeans(5, "elkan")
|
||||
|
||||
@unittest.skipIf(MeanShift is None, reason="MeanShift is supported in scikit-learn >= 0.22")
|
||||
def _test_mean_shift(self, bandwidth=None, backend="torch", extra_config={}):
|
||||
for cluster_all in [True, False]:
|
||||
model = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all)
|
||||
np.random.seed(0)
|
||||
X = np.random.rand(100, 200)
|
||||
X = np.array(X, dtype=np.float32)
|
||||
|
||||
model.fit(X)
|
||||
torch_model = hummingbird.ml.convert(model, backend, X, extra_config=extra_config)
|
||||
self.assertTrue(torch_model is not None)
|
||||
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-6, atol=1e-6)
|
||||
|
||||
# MeanShift default
|
||||
def test_mean_shift(self):
|
||||
self._test_mean_shift()
|
||||
|
||||
# MeanShift bandwdith 2.0
|
||||
def test_mean_shift_bandwdith(self):
|
||||
self._test_mean_shift(2.0)
|
||||
|
||||
# MeanShift bandwdith 5.0
|
||||
def test_mean_shift_bandwdith_5(self):
|
||||
self._test_mean_shift(5.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Загрузка…
Ссылка в новой задаче