* add meanshift
This commit is contained in:
Matteo Interlandi 2021-03-25 09:44:34 -07:00 коммит произвёл GitHub
Родитель 55f63fe3a3
Коммит 7b717b98f3
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 37 добавлений и 4 удалений

Просмотреть файл

@ -30,13 +30,13 @@ from .onnx import scaler as onnx_scaler # noqa: E402, F811
from .onnx import sv as onnx_sv # noqa: E402, F811
from .onnx import tree_ensemble # noqa: E402
from .sklearn import array_feature_extractor as sklearn_afe # noqa: E402
from .sklearn import cluster # noqa: E402
from .sklearn import decision_tree # noqa: E402
from .sklearn import decomposition # noqa: E402
from .sklearn import discretizer as sklearn_discretizer # noqa: E402
from .sklearn import gbdt # noqa: E402
from .sklearn import iforest # noqa: E402
from .sklearn import imputer # noqa: E402
from .sklearn import kmeans # noqa: E402
from .sklearn import kneighbors # noqa: E402
from .sklearn import label_encoder # noqa: E402
from .sklearn import linear as sklearn_linear # noqa: E402

Просмотреть файл

@ -6,7 +6,7 @@
# --------------------------------------------------------------------------
"""
Converters for scikit-learn KMeans.
Converters for scikit-learn KMeans and MeanShift models.
"""
import torch
@ -40,3 +40,4 @@ def convert_sklearn_kmeans_model(operator, device, extra_config):
# Register the same converter function under both operator names.
# NOTE(review): presumably MeanShift can reuse the KMeans converter because
# prediction for both assigns each sample to its nearest cluster center --
# confirm against the body of convert_sklearn_kmeans_model.
register_converter("SklearnKMeans", convert_sklearn_kmeans_model)
register_converter("SklearnMeanShift", convert_sklearn_kmeans_model)

Просмотреть файл

@ -40,6 +40,7 @@ LogisticRegression,
LogisticRegressionCV,
RidgeCV,
MaxAbsScaler,
MeanShift,
MinMaxScaler,
MissingIndicator,
MLPClassifier,
@ -156,7 +157,7 @@ def _build_sklearn_operator_list():
from sklearn.neighbors import KNeighborsRegressor
# Clustering models
from sklearn.cluster import KMeans
from sklearn.cluster import KMeans, MeanShift
# Preprocessing
from sklearn.preprocessing import (
@ -205,6 +206,7 @@ def _build_sklearn_operator_list():
RidgeCV,
# Clustering
KMeans,
MeanShift,
# Other models
BernoulliNB,
GaussianNB,

Просмотреть файл

@ -12,8 +12,13 @@ import hummingbird.ml
from hummingbird.ml import constants
from hummingbird.ml._utils import tvm_installed
# MeanShift is imported defensively: when the import fails the name is bound
# to None, which the MeanShift test helper uses as its skip condition.
try:
    from sklearn.cluster import MeanShift
except ImportError:
    # Only a failed import means "not available"; any other error is a real
    # problem and should surface instead of being silently swallowed.
    MeanShift = None
class TestSklearnKMeans(unittest.TestCase):
class TestSklearnClustering(unittest.TestCase):
# KMeans test function to be parameterized
def _test_kmeans(self, n_clusters, algorithm="full", backend="torch", extra_config={}):
model = KMeans(n_clusters=n_clusters, algorithm=algorithm, random_state=0)
@ -50,6 +55,31 @@ class TestSklearnKMeans(unittest.TestCase):
def test_kmeans_5_elkan(self):
    # Five clusters, fitted with the "elkan" algorithm variant.
    self._test_kmeans(n_clusters=5, algorithm="elkan")
@unittest.skipIf(MeanShift is None, reason="MeanShift is supported in scikit-learn >= 0.22")
def _test_mean_shift(self, bandwidth=None, backend="torch", extra_config=None):
    """
    MeanShift test function to be parameterized.

    Fits a MeanShift model on fixed pseudo-random data for both
    ``cluster_all=True`` and ``cluster_all=False``, converts it with
    ``hummingbird.ml.convert`` and checks that the converted model's
    predictions match those of the scikit-learn model.

    Args:
        bandwidth: Value forwarded to ``MeanShift(bandwidth=...)``.
        backend: Backend name passed to ``hummingbird.ml.convert``.
        extra_config: Optional dict passed as ``extra_config`` to the
            conversion; ``None`` is treated as an empty dict.
    """
    # A fresh dict per call avoids sharing one mutable default object
    # between every invocation of this helper.
    if extra_config is None:
        extra_config = {}

    # The seed is fixed, so the data is identical for every configuration;
    # build it once instead of once per loop iteration.
    np.random.seed(0)
    X = np.array(np.random.rand(100, 200), dtype=np.float32)

    for cluster_all in [True, False]:
        model = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all)
        model.fit(X)

        torch_model = hummingbird.ml.convert(model, backend, X, extra_config=extra_config)

        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-6, atol=1e-6)
# MeanShift default
def test_mean_shift(self):
    # No explicit bandwidth: pass the helper's default value (None).
    self._test_mean_shift(bandwidth=None)
# MeanShift bandwidth 2.0
def test_mean_shift_bandwdith(self):
    # Run the shared MeanShift check with the bandwidth fixed at 2.0.
    self._test_mean_shift(bandwidth=2.0)
# MeanShift bandwidth 5.0
def test_mean_shift_bandwdith_5(self):
    # Run the shared MeanShift check with the bandwidth fixed at 5.0.
    self._test_mean_shift(bandwidth=5.0)
if __name__ == "__main__":
unittest.main()