2020-04-29 19:09:41 +03:00
|
|
|
"""
|
|
|
|
Tests Sklearn RandomForest, DecisionTree, ExtraTrees converters.
|
|
|
|
"""
|
|
|
|
import unittest
|
|
|
|
import warnings
|
2024-02-13 22:53:20 +03:00
|
|
|
import sys
|
2020-04-29 19:09:41 +03:00
|
|
|
|
|
|
|
import numpy as np
|
2020-06-04 08:33:02 +03:00
|
|
|
from sklearn.ensemble import ExtraTreesClassifier, ExtraTreesRegressor, RandomForestClassifier, RandomForestRegressor
|
|
|
|
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
|
2021-01-14 01:27:09 +03:00
|
|
|
from sklearn import datasets
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-05-06 02:34:12 +03:00
|
|
|
import hummingbird.ml
|
|
|
|
from hummingbird.ml.exceptions import MissingConverter
|
2024-02-13 22:53:20 +03:00
|
|
|
from hummingbird.ml._utils import tvm_installed, is_on_github_actions
|
2020-11-04 00:21:02 +03:00
|
|
|
from hummingbird.ml import constants
|
2020-04-29 19:09:41 +03:00
|
|
|
from tree_utils import dt_implementation_map
|
|
|
|
|
2021-06-15 23:23:43 +03:00
|
|
|
import random
|
|
|
|
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-04 08:33:02 +03:00
|
|
|
class TestSklearnTreeConverter(unittest.TestCase):
|
2020-04-29 19:09:41 +03:00
|
|
|
# Check tree implementation
|
|
|
|
def test_random_forest_implementation(self):
    """
    Check that the requested tree implementation is the one actually used.

    A single-sample random forest classifier and regressor are converted to
    the "torch" backend once per implementation name, and the type string of
    the first converted operator is compared with ``dt_implementation_map``.
    """
    warnings.filterwarnings("ignore")
    np.random.seed(0)
    X = np.random.rand(1, 1)
    X = np.array(X, dtype=np.float32)
    y = np.random.randint(2, size=1)

    models = [
        RandomForestClassifier(n_estimators=1, max_depth=1),
        RandomForestRegressor(n_estimators=1, max_depth=1),
    ]
    for model in models:
        # Fitting does not depend on the implementation name, so do it once per model.
        model.fit(X, y)
        for extra_config_param in ["tree_trav", "perf_tree_trav", "gemm"]:
            # subTest reports which model/implementation pair failed and keeps going.
            with self.subTest(model=type(model).__name__, implementation=extra_config_param):
                torch_model = hummingbird.ml.convert(
                    model, "torch", extra_config={constants.TREE_IMPLEMENTATION: extra_config_param}
                )
                self.assertIsNotNone(torch_model)
                # assertEqual shows both strings on failure, unlike assertTrue(a == b).
                self.assertEqual(
                    str(type(list(torch_model.model._operators)[0])),
                    dt_implementation_map[extra_config_param],
                )
|
2020-05-01 20:01:51 +03:00
|
|
|
|
2020-06-06 01:34:42 +03:00
|
|
|
# Used for classification tests
|
2020-08-14 00:03:09 +03:00
|
|
|
def _run_tree_classification_converter(
    self, model_type, num_classes, backend="torch", extra_config=None, labels_shift=0, **kwargs
):
    """
    Fit a classifier on random data at several depths and compare it with its conversion.

    Args:
        model_type: Estimator class, instantiated as ``model_type(max_depth=..., **kwargs)``.
        num_classes: Exclusive upper bound of the random integer labels.
        backend: Backend name passed to ``hummingbird.ml.convert``.
        extra_config: Optional dict forwarded as ``extra_config``; ``None`` means an empty dict.
        labels_shift: Constant added to every generated label.
        **kwargs: Extra keyword arguments for the estimator constructor.
    """
    # A None sentinel avoids sharing one mutable default dict between calls.
    extra_config = {} if extra_config is None else extra_config
    warnings.filterwarnings("ignore")
    for max_depth in [1, 3, 8, 10, 12, None]:
        # Reseeding inside the loop gives every depth the same X and y.
        np.random.seed(0)
        X = np.random.rand(100, 200)
        X = np.array(X, dtype=np.float32)
        y = np.random.randint(num_classes, size=100) + labels_shift

        model = model_type(max_depth=max_depth, **kwargs)
        model.fit(X, y)
        torch_model = hummingbird.ml.convert(model, backend, X, extra_config=extra_config)
        self.assertIsNotNone(torch_model)
        # Both the class probabilities and the predicted labels must match.
        np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
        np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)
|
2021-03-19 19:17:15 +03:00
|
|
|
|
2020-06-06 01:34:42 +03:00
|
|
|
# Random forest binary classifier
|
2020-04-29 19:09:41 +03:00
|
|
|
def test_random_forest_classifier_binary_converter(self):
    """Run the classification check on a 2-class RandomForestClassifier with 10 estimators."""
    self._run_tree_classification_converter(
        model_type=RandomForestClassifier,
        num_classes=2,
        n_estimators=10,
    )
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-06 01:34:42 +03:00
|
|
|
# Random forest gemm classifier
|
2020-04-29 19:09:41 +03:00
|
|
|
def test_random_forest_gemm_classifier_converter(self):
    """Run the 2-class random forest classification check with the "gemm" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "gemm"}
    self._run_tree_classification_converter(RandomForestClassifier, 2, extra_config=tree_config, n_estimators=10)
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-06 01:34:42 +03:00
|
|
|
# Random forest tree_trav classifier
|
2020-04-29 19:09:41 +03:00
|
|
|
def test_random_forest_tree_trav_classifier_converter(self):
    """Run the 2-class random forest classification check with the "tree_trav" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "tree_trav"}
    self._run_tree_classification_converter(RandomForestClassifier, 2, extra_config=tree_config, n_estimators=10)
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-06 01:34:42 +03:00
|
|
|
# Random forest perf_tree_trav classifier
|
2020-04-29 19:09:41 +03:00
|
|
|
def test_random_forest_perf_tree_trav_classifier_converter(self):
    """Run the 2-class random forest classification check with the "perf_tree_trav" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
    self._run_tree_classification_converter(RandomForestClassifier, 2, extra_config=tree_config, n_estimators=10)
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-06 01:34:42 +03:00
|
|
|
# Random forest multi classifier
|
2020-04-29 19:09:41 +03:00
|
|
|
def test_random_forest_multi_classifier_converter(self):
    """Run the classification check on a 3-class RandomForestClassifier with 10 estimators."""
    self._run_tree_classification_converter(
        model_type=RandomForestClassifier,
        num_classes=3,
        n_estimators=10,
    )
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-06 01:34:42 +03:00
|
|
|
# Random forest gemm multi classifier
|
2020-04-29 19:09:41 +03:00
|
|
|
def test_random_forest_gemm_multi_classifier_converter(self):
    """Run the 3-class random forest classification check with the "gemm" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "gemm"}
    self._run_tree_classification_converter(RandomForestClassifier, 3, extra_config=tree_config, n_estimators=10)
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-06 01:34:42 +03:00
|
|
|
# Random forest tree_trav multi classifier
|
2020-04-29 19:09:41 +03:00
|
|
|
def test_random_forest_tree_trav_multi_classifier_converter(self):
    """Run the 3-class random forest classification check with the "tree_trav" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "tree_trav"}
    self._run_tree_classification_converter(RandomForestClassifier, 3, extra_config=tree_config, n_estimators=10)
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-06 01:34:42 +03:00
|
|
|
# Random forest perf_tree_trav multi classifier
|
2020-04-29 19:09:41 +03:00
|
|
|
def test_random_forest_perf_tree_trav_multi_classifier_converter(self):
    """Run the 3-class random forest classification check with the "perf_tree_trav" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
    self._run_tree_classification_converter(RandomForestClassifier, 3, extra_config=tree_config, n_estimators=10)
|
2020-06-06 01:34:42 +03:00
|
|
|
|
|
|
|
# Random forest gemm classifier shifted classes
|
|
|
|
def test_random_forest_gemm_classifier_shifted_labels_converter(self):
    """Run the 3-class "gemm" random forest check with every label shifted by 2."""
    tree_config = {constants.TREE_IMPLEMENTATION: "gemm"}
    self._run_tree_classification_converter(
        RandomForestClassifier,
        3,
        labels_shift=2,
        extra_config=tree_config,
        n_estimators=10,
    )
|
2020-06-06 01:34:42 +03:00
|
|
|
|
|
|
|
# Random forest tree_trav classifier shifted classes
|
|
|
|
def test_random_forest_tree_trav_classifier_shifted_labels_converter(self):
    """Run the 3-class "tree_trav" random forest check with every label shifted by 2."""
    tree_config = {constants.TREE_IMPLEMENTATION: "tree_trav"}
    self._run_tree_classification_converter(
        RandomForestClassifier, 3, labels_shift=2, extra_config=tree_config, n_estimators=10
    )
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-06 01:34:42 +03:00
|
|
|
# Random forest perf_tree_trav classifier shifted classes
|
|
|
|
def test_random_forest_perf_tree_trav_classifier_shifted_labels_converter(self):
    """Run the 3-class "perf_tree_trav" random forest check with every label shifted by 2."""
    tree_config = {constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
    self._run_tree_classification_converter(
        RandomForestClassifier, 3, labels_shift=2, extra_config=tree_config, n_estimators=10
    )
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-04 08:33:02 +03:00
|
|
|
# Used for regression tests
|
2020-08-14 00:03:09 +03:00
|
|
|
def _run_tree_regressor_converter(self, model_type, num_classes, backend="torch", extra_config=None, **kwargs):
    """
    Fit a regressor on random data at several depths and compare it with its conversion.

    Args:
        model_type: Estimator class, instantiated as ``model_type(max_depth=..., **kwargs)``.
        num_classes: Exclusive upper bound of the random integer targets.
        backend: Backend name passed to ``hummingbird.ml.convert``.
        extra_config: Optional dict forwarded as ``extra_config``; ``None`` means an empty dict.
        **kwargs: Extra keyword arguments for the estimator constructor.
    """
    # A None sentinel avoids sharing one mutable default dict between calls.
    extra_config = {} if extra_config is None else extra_config
    warnings.filterwarnings("ignore")
    for max_depth in [1, 3, 8, 10, 12, None]:
        model = model_type(max_depth=max_depth, **kwargs)
        # Reseeding inside the loop gives every depth the same X and y.
        np.random.seed(0)
        X = np.random.rand(100, 200)
        X = np.array(X, dtype=np.float32)
        y = np.random.randint(num_classes, size=100)

        model.fit(X, y)
        torch_model = hummingbird.ml.convert(model, backend, X, extra_config=extra_config)
        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-04 08:33:02 +03:00
|
|
|
# Random forest regressor
|
2020-04-29 19:09:41 +03:00
|
|
|
def test_random_forest_regressor_converter(self):
    """Run the regression check on a RandomForestRegressor with 10 estimators and targets below 1000."""
    self._run_tree_regressor_converter(
        model_type=RandomForestRegressor,
        num_classes=1000,
        n_estimators=10,
    )
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-04 08:33:02 +03:00
|
|
|
# Random forest gemm regressor
|
2020-04-29 19:09:41 +03:00
|
|
|
def test_random_forest_gemm_regressor_converter(self):
    """Run the random forest regression check with the "gemm" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "gemm"}
    self._run_tree_regressor_converter(RandomForestRegressor, 1000, extra_config=tree_config, n_estimators=10)
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-04 08:33:02 +03:00
|
|
|
# Random forest tree_trav regressor
|
2020-04-29 19:09:41 +03:00
|
|
|
def test_random_forest_tree_trav_regressor_converter(self):
    """Run the random forest regression check with the "tree_trav" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "tree_trav"}
    self._run_tree_regressor_converter(RandomForestRegressor, 1000, extra_config=tree_config, n_estimators=10)
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-04 08:33:02 +03:00
|
|
|
# Random forest perf_tree_trav regressor
|
2020-04-29 19:09:41 +03:00
|
|
|
def test_random_forest_perf_tree_trav_regressor_converter(self):
    """Run the random forest regression check with the "perf_tree_trav" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
    self._run_tree_regressor_converter(RandomForestRegressor, 1000, extra_config=tree_config, n_estimators=10)
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-04 08:33:02 +03:00
|
|
|
# Extra trees regressor
|
2020-06-02 19:37:30 +03:00
|
|
|
def test_extra_trees_regressor_converter(self):
    """Run the regression check on an ExtraTreesRegressor with 10 estimators and targets below 1000."""
    self._run_tree_regressor_converter(
        model_type=ExtraTreesRegressor,
        num_classes=1000,
        n_estimators=10,
    )
|
2020-06-02 19:37:30 +03:00
|
|
|
|
2020-06-04 08:33:02 +03:00
|
|
|
# Extra trees gemm regressor
|
2020-06-02 19:37:30 +03:00
|
|
|
def test_extra_trees_gemm_regressor_converter(self):
    """Run the extra trees regression check with the "gemm" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "gemm"}
    self._run_tree_regressor_converter(ExtraTreesRegressor, 1000, extra_config=tree_config, n_estimators=10)
|
2020-06-02 19:37:30 +03:00
|
|
|
|
2020-06-04 08:33:02 +03:00
|
|
|
# Extra trees tree_trav regressor
|
2020-06-02 19:37:30 +03:00
|
|
|
def test_extra_trees_tree_trav_regressor_converter(self):
    """Run the extra trees regression check with the "tree_trav" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "tree_trav"}
    self._run_tree_regressor_converter(ExtraTreesRegressor, 1000, extra_config=tree_config, n_estimators=10)
|
2020-06-02 19:37:30 +03:00
|
|
|
|
2020-06-04 08:33:02 +03:00
|
|
|
# Extra trees perf_tree_trav regressor
|
2020-06-02 19:37:30 +03:00
|
|
|
def test_extra_trees_perf_tree_trav_regressor_converter(self):
    """Run the extra trees regression check with the "perf_tree_trav" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
    self._run_tree_regressor_converter(ExtraTreesRegressor, 1000, extra_config=tree_config, n_estimators=10)
|
2020-06-04 08:33:02 +03:00
|
|
|
|
|
|
|
# Decision tree regressor
|
|
|
|
def test_decision_tree_regressor_converter(self):
    """Run the regression check on a DecisionTreeRegressor with targets below 1000."""
    self._run_tree_regressor_converter(model_type=DecisionTreeRegressor, num_classes=1000)
|
|
|
|
|
|
|
|
# Decision tree gemm regressor
|
|
|
|
def test_decision_tree_gemm_regressor_converter(self):
    """Run the decision tree regression check with the "gemm" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "gemm"}
    self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, extra_config=tree_config)
|
2020-06-04 08:33:02 +03:00
|
|
|
|
|
|
|
# Decision tree tree_trav regressor
|
|
|
|
def test_decision_tree_tree_trav_regressor_converter(self):
    """Run the decision tree regression check with the "tree_trav" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "tree_trav"}
    self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, extra_config=tree_config)
|
2020-06-04 08:33:02 +03:00
|
|
|
|
|
|
|
# Decision tree perf_tree_trav regressor
|
|
|
|
def test_decision_tree_perf_tree_trav_regressor_converter(self):
    """Run the decision tree regression check with the "perf_tree_trav" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
    self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, extra_config=tree_config)
|
2020-06-02 19:37:30 +03:00
|
|
|
|
2020-06-06 01:34:42 +03:00
|
|
|
# Decision tree classifier
|
|
|
|
def test_decision_tree_classifier_converter(self):
    """Run the classification check on a 3-class DecisionTreeClassifier."""
    self._run_tree_classification_converter(model_type=DecisionTreeClassifier, num_classes=3)
|
|
|
|
|
|
|
|
# Extra trees classifier
|
|
|
|
def test_extra_trees_classifier_converter(self):
    """Run the classification check on a 3-class ExtraTreesClassifier with 10 estimators."""
    self._run_tree_classification_converter(
        model_type=ExtraTreesClassifier,
        num_classes=3,
        n_estimators=10,
    )
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-06 01:34:42 +03:00
|
|
|
# Used for small tree tests
|
|
|
|
def _run_random_forest_classifier_single_node_tree_converter(self, extra_config=None):
    """
    Convert a one-estimator random forest fitted on a single sample and compare probabilities.

    Args:
        extra_config: Optional dict forwarded to ``hummingbird.ml.convert``;
            ``None`` means an empty dict.
    """
    # A None sentinel avoids sharing one mutable default dict between calls.
    extra_config = {} if extra_config is None else extra_config
    warnings.filterwarnings("ignore")
    np.random.seed(0)
    X = np.random.rand(1, 1)
    X = np.array(X, dtype=np.float32)
    # randint(1, ...) can only produce 0, so there is a single class.
    y = np.random.randint(1, size=1)
    model = RandomForestClassifier(n_estimators=1).fit(X, y)
    torch_model = hummingbird.ml.convert(model, "torch", extra_config=extra_config)
    self.assertIsNotNone(torch_model)
    np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-06 01:34:42 +03:00
|
|
|
# Small tree gemm implementation
|
|
|
|
def test_random_forest_gemm_classifier_single_node_tree_converter(self):
    """Run the single-sample random forest check with the "gemm" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "gemm"}
    self._run_random_forest_classifier_single_node_tree_converter(extra_config=tree_config)
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-06 01:34:42 +03:00
|
|
|
# Small tree tree_trav implementation
|
|
|
|
def test_random_forest_tree_trav_classifier_single_node_tree_converter(self):
    """Run the single-sample random forest check with the "tree_trav" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "tree_trav"}
    self._run_random_forest_classifier_single_node_tree_converter(extra_config=tree_config)
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-06-06 01:34:42 +03:00
|
|
|
# Small tree perf_tree_trav implementation
|
|
|
|
def test_random_forest_perf_tree_trav_classifier_single_node_tree_converter(self):
    """Run the single-sample random forest check with the "perf_tree_trav" tree implementation."""
    tree_config = {constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
    self._run_random_forest_classifier_single_node_tree_converter(extra_config=tree_config)
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2021-01-08 04:17:08 +03:00
|
|
|
# Another small tree tests
|
|
|
|
def test_random_forest_classifier_small_tree_converter(self):
    """Convert a random forest fitted on 9 samples of 8 features and compare class probabilities."""
    seed = 0
    np.random.seed(seed=seed)
    n_samples = 9
    features = np.random.randn(n_samples, 8)
    labels = np.random.randint(low=0, high=2, size=n_samples)

    model = RandomForestClassifier(random_state=seed).fit(features, labels)
    torch_model = hummingbird.ml.convert(model, "torch")
    self.assertIsNotNone(torch_model)

    sklearn_proba = model.predict_proba(features)
    converted_proba = torch_model.predict_proba(features)
    np.testing.assert_allclose(sklearn_proba, converted_proba, rtol=1e-06, atol=1e-06)
|
|
|
|
|
2020-07-17 20:16:27 +03:00
|
|
|
# Float 64 classification test helper
|
|
|
|
def _run_float64_tree_classification_converter(self, model_type, num_classes, extra_config=None, labels_shift=0, **kwargs):
    """
    Fit a classifier on un-cast random data at several depths and compare probabilities.

    Unlike the float32 helper, X is left as returned by ``np.random.rand`` and
    no sample input is passed to ``hummingbird.ml.convert``.

    Args:
        model_type: Estimator class, instantiated as ``model_type(max_depth=..., **kwargs)``.
        num_classes: Exclusive upper bound of the random integer labels.
        extra_config: Optional dict forwarded as ``extra_config``; ``None`` means an empty dict.
        labels_shift: Constant added to every generated label.
        **kwargs: Extra keyword arguments for the estimator constructor.
    """
    # A None sentinel avoids sharing one mutable default dict between calls.
    extra_config = {} if extra_config is None else extra_config
    warnings.filterwarnings("ignore")
    for max_depth in [1, 3, 8, 10, 12, None]:
        # Reseeding inside the loop gives every depth the same X and y.
        np.random.seed(0)
        X = np.random.rand(100, 200)
        y = np.random.randint(num_classes, size=100) + labels_shift

        model = model_type(max_depth=max_depth, **kwargs)
        model.fit(X, y)
        torch_model = hummingbird.ml.convert(model, "torch", extra_config=extra_config)
        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
|
|
|
|
|
|
|
|
# Random forest binary classifier (float64 data)
|
|
|
|
def test_float64_random_forest_classifier_binary_converter(self):
    """Run the float64 classification check on a 2-class RandomForestClassifier with 10 estimators."""
    self._run_float64_tree_classification_converter(
        model_type=RandomForestClassifier,
        num_classes=2,
        n_estimators=10,
    )
|
|
|
|
|
|
|
|
# Decision tree classifier (float64 data)
|
|
|
|
def test_float64_decision_tree_classifier_converter(self):
    """Run the float64 classification check on a 3-class DecisionTreeClassifier."""
    self._run_float64_tree_classification_converter(model_type=DecisionTreeClassifier, num_classes=3)
|
|
|
|
|
|
|
|
# Extra trees classifier (float64 data)
|
|
|
|
def test_float64_extra_trees_classifier_converter(self):
    """Run the float64 classification check on a 3-class ExtraTreesClassifier with 10 estimators."""
    self._run_float64_tree_classification_converter(
        model_type=ExtraTreesClassifier,
        num_classes=3,
        n_estimators=10,
    )
|
|
|
|
|
|
|
|
# Float 64 regression tests helper
|
|
|
|
def _run_float64_tree_regressor_converter(self, model_type, num_classes, extra_config=None, **kwargs):
    """
    Fit a regressor on un-cast random data at several depths and compare predictions.

    Unlike the float32 helper, X is left as returned by ``np.random.rand`` and
    no sample input is passed to ``hummingbird.ml.convert``.

    Args:
        model_type: Estimator class, instantiated as ``model_type(max_depth=..., **kwargs)``.
        num_classes: Exclusive upper bound of the random integer targets.
        extra_config: Optional dict forwarded as ``extra_config``; ``None`` means an empty dict.
        **kwargs: Extra keyword arguments for the estimator constructor.
    """
    # A None sentinel avoids sharing one mutable default dict between calls.
    extra_config = {} if extra_config is None else extra_config
    warnings.filterwarnings("ignore")
    for max_depth in [1, 3, 8, 10, 12, None]:
        model = model_type(max_depth=max_depth, **kwargs)
        # Reseeding inside the loop gives every depth the same X and y.
        np.random.seed(0)
        X = np.random.rand(100, 200)
        y = np.random.randint(num_classes, size=100)

        model.fit(X, y)
        torch_model = hummingbird.ml.convert(model, "torch", extra_config=extra_config)
        self.assertIsNotNone(torch_model)
        np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)
|
|
|
|
|
|
|
|
# Random forest regressor (float64 data)
|
|
|
|
def test_float64_random_forest_regressor_converter(self):
    """Run the float64 regression check on a RandomForestRegressor with 10 estimators."""
    self._run_float64_tree_regressor_converter(
        model_type=RandomForestRegressor,
        num_classes=1000,
        n_estimators=10,
    )
|
|
|
|
|
|
|
|
# Decision tree regressor (float64 data)
|
|
|
|
def test_float64_decision_tree_regressor_converter(self):
    """Run the float64 regression check on a DecisionTreeRegressor."""
    self._run_float64_tree_regressor_converter(model_type=DecisionTreeRegressor, num_classes=1000)
|
|
|
|
|
|
|
|
# Extra trees regressor (float64 data)
|
|
|
|
def test_float64_extra_trees_regressor_converter(self):
    """Run the float64 regression check on an ExtraTreesRegressor with 10 estimators."""
    self._run_float64_tree_regressor_converter(
        model_type=ExtraTreesRegressor,
        num_classes=1000,
        n_estimators=10,
    )
|
|
|
|
|
2020-04-29 19:09:41 +03:00
|
|
|
# Failure Cases
|
|
|
|
def test_random_forest_classifier_raises_wrong_type(self):
    """Expect RuntimeError when converting a classifier that was fitted on float32 labels."""
    warnings.filterwarnings("ignore")
    np.random.seed(0)
    features = np.array(np.random.rand(100, 200), dtype=np.float32)
    # y must be int, not float, should error
    labels = np.random.randint(3, size=100).astype(np.float32)

    model = RandomForestClassifier(n_estimators=10)
    model.fit(features, labels)
    with self.assertRaises(RuntimeError):
        hummingbird.ml.convert(model, "torch")
|
2020-04-29 19:09:41 +03:00
|
|
|
|
|
|
|
def test_random_forest_classifier_raises_wrong_extra_config(self):
    """Expect MissingConverter when the tree implementation name is not a known one."""
    warnings.filterwarnings("ignore")
    np.random.seed(0)
    features = np.random.rand(100, 200)
    features = np.array(features, dtype=np.float32)
    labels = np.random.randint(3, size=100)

    model = RandomForestClassifier(n_estimators=10)
    model.fit(features, labels)
    bad_config = {constants.TREE_IMPLEMENTATION: "nonsense"}
    with self.assertRaises(MissingConverter):
        hummingbird.ml.convert(model, "torch", extra_config=bad_config)
|
2020-04-29 19:09:41 +03:00
|
|
|
|
2020-08-14 00:03:09 +03:00
|
|
|
# Test trees with TorchScript backend
|
|
|
|
# Random forest binary classifier
|
|
|
|
def test_random_forest_ts_classifier_binary_converter(self):
    """Run the 2-class random forest classification check on the "torch.jit" backend."""
    self._run_tree_classification_converter(
        RandomForestClassifier,
        2,
        backend="torch.jit",
        n_estimators=10,
    )
|
|
|
|
|
|
|
|
# Random forest gemm classifier
|
|
|
|
def test_random_forest_ts_gemm_classifier_converter(self):
    """Run the 2-class "gemm" random forest classification check on the "torch.jit" backend."""
    tree_config = {constants.TREE_IMPLEMENTATION: "gemm"}
    self._run_tree_classification_converter(
        RandomForestClassifier, 2, backend="torch.jit", extra_config=tree_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest tree_trav classifier
|
|
|
|
def test_random_forest_ts_tree_trav_classifier_converter(self):
    """Run the 2-class "tree_trav" random forest classification check on the "torch.jit" backend."""
    tree_config = {constants.TREE_IMPLEMENTATION: "tree_trav"}
    self._run_tree_classification_converter(
        RandomForestClassifier, 2, backend="torch.jit", extra_config=tree_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest perf_tree_trav classifier
|
|
|
|
def test_random_forest_ts_perf_tree_trav_classifier_converter(self):
    """Run the 2-class "perf_tree_trav" random forest classification check on the "torch.jit" backend."""
    tree_config = {constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
    self._run_tree_classification_converter(
        RandomForestClassifier, 2, backend="torch.jit", extra_config=tree_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest multi classifier
|
|
|
|
def test_random_forest_ts_multi_classifier_converter(self):
    """Run the 3-class random forest classification check on the "torch.jit" backend."""
    self._run_tree_classification_converter(
        RandomForestClassifier,
        3,
        backend="torch.jit",
        n_estimators=10,
    )
|
|
|
|
|
|
|
|
# Random forest gemm multi classifier
|
|
|
|
def test_random_forest_ts_gemm_multi_classifier_converter(self):
    """Run the 3-class "gemm" random forest classification check on the "torch.jit" backend."""
    tree_config = {constants.TREE_IMPLEMENTATION: "gemm"}
    self._run_tree_classification_converter(
        RandomForestClassifier, 3, backend="torch.jit", extra_config=tree_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest tree_trav multi classifier
|
|
|
|
def test_random_forest_ts_tree_trav_multi_classifier_converter(self):
    """Run the 3-class "tree_trav" random forest classification check on the "torch.jit" backend."""
    tree_config = {constants.TREE_IMPLEMENTATION: "tree_trav"}
    self._run_tree_classification_converter(
        RandomForestClassifier, 3, backend="torch.jit", extra_config=tree_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest perf_tree_trav multi classifier
|
|
|
|
def test_random_forest_ts_perf_tree_trav_multi_classifier_converter(self):
    """Run the 3-class "perf_tree_trav" random forest classification check on the "torch.jit" backend."""
    tree_config = {constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
    self._run_tree_classification_converter(
        RandomForestClassifier, 3, backend="torch.jit", extra_config=tree_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest gemm classifier shifted classes
|
|
|
|
def test_random_forest_ts_gemm_classifier_shifted_labels_converter(self):
    """Run the 3-class "gemm" random forest check on "torch.jit" with every label shifted by 2."""
    tree_config = {constants.TREE_IMPLEMENTATION: "gemm"}
    self._run_tree_classification_converter(
        RandomForestClassifier, 3, backend="torch.jit", labels_shift=2, extra_config=tree_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest tree_trav classifier shifted classes
|
|
|
|
def test_random_forest_ts_tree_trav_classifier_shifted_labels_converter(self):
    """Run the 3-class "tree_trav" random forest check on "torch.jit" with every label shifted by 2."""
    tree_config = {constants.TREE_IMPLEMENTATION: "tree_trav"}
    self._run_tree_classification_converter(
        RandomForestClassifier, 3, backend="torch.jit", labels_shift=2, extra_config=tree_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest perf_tree_trav classifier shifted classes
|
|
|
|
def test_random_forest_ts_perf_tree_trav_classifier_shifted_labels_converter(self):
    """Run the 3-class "perf_tree_trav" random forest check on "torch.jit" with every label shifted by 2."""
    tree_config = {constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
    self._run_tree_classification_converter(
        RandomForestClassifier, 3, backend="torch.jit", labels_shift=2, extra_config=tree_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest regressor
|
|
|
|
def test_random_forest_ts_regressor_converter(self):
    """Run the random forest regression check on the "torch.jit" backend."""
    self._run_tree_regressor_converter(
        RandomForestRegressor,
        1000,
        backend="torch.jit",
        n_estimators=10,
    )
|
|
|
|
|
|
|
|
# Random forest gemm regressor
|
|
|
|
def test_random_forest_ts_gemm_regressor_converter(self):
    """Run the "gemm" random forest regression check on the "torch.jit" backend."""
    tree_config = {constants.TREE_IMPLEMENTATION: "gemm"}
    self._run_tree_regressor_converter(
        RandomForestRegressor, 1000, backend="torch.jit", extra_config=tree_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest tree_trav regressor
|
|
|
|
def test_random_forest_ts_tree_trav_regressor_converter(self):
    """Run the "tree_trav" random forest regression check on the "torch.jit" backend."""
    tree_config = {constants.TREE_IMPLEMENTATION: "tree_trav"}
    self._run_tree_regressor_converter(
        RandomForestRegressor, 1000, backend="torch.jit", extra_config=tree_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest perf_tree_trav regressor
|
|
|
|
def test_random_forest_ts_perf_tree_trav_regressor_converter(self):
    """Run the "perf_tree_trav" random forest regression check on the "torch.jit" backend."""
    tree_config = {constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
    self._run_tree_regressor_converter(
        RandomForestRegressor, 1000, backend="torch.jit", extra_config=tree_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Extra trees regressor
|
|
|
|
def test_extra_trees_ts_regressor_converter(self):
    """Run the extra trees regression check on the "torch.jit" backend."""
    self._run_tree_regressor_converter(
        ExtraTreesRegressor,
        1000,
        backend="torch.jit",
        n_estimators=10,
    )
|
|
|
|
|
|
|
|
# Extra trees gemm regressor
|
|
|
|
def test_extra_trees_ts_gemm_regressor_converter(self):
    """Run the "gemm" extra trees regression check on the "torch.jit" backend."""
    tree_config = {constants.TREE_IMPLEMENTATION: "gemm"}
    self._run_tree_regressor_converter(
        ExtraTreesRegressor, 1000, backend="torch.jit", extra_config=tree_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Extra trees tree_trav regressor
|
|
|
|
def test_extra_trees_ts_tree_trav_regressor_converter(self):
    """Run the "tree_trav" extra trees regression check on the "torch.jit" backend."""
    tree_config = {constants.TREE_IMPLEMENTATION: "tree_trav"}
    self._run_tree_regressor_converter(
        ExtraTreesRegressor, 1000, backend="torch.jit", extra_config=tree_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Extra trees perf_tree_trav regressor
|
|
|
|
def test_extra_trees_ts_perf_tree_trav_regressor_converter(self):
    """Run the "perf_tree_trav" extra trees regression check on the "torch.jit" backend."""
    tree_config = {constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
    self._run_tree_regressor_converter(
        ExtraTreesRegressor, 1000, backend="torch.jit", extra_config=tree_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Decision tree regressor
|
|
|
|
def test_decision_tree_ts_regressor_converter(self):
    """Run the decision tree regression check on the "torch.jit" backend."""
    self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, backend="torch.jit")
|
|
|
|
|
|
|
|
# Decision tree gemm regressor
|
|
|
|
def test_decision_tree_ts_gemm_regressor_converter(self):
    """Run the "gemm" decision tree regression check on the "torch.jit" backend."""
    tree_config = {constants.TREE_IMPLEMENTATION: "gemm"}
    self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, backend="torch.jit", extra_config=tree_config)
|
|
|
|
|
|
|
|
# Decision tree tree_trav regressor
|
|
|
|
def test_decision_tree_ts_tree_trav_regressor_converter(self):
    """Run the "tree_trav" decision tree regression check on the "torch.jit" backend."""
    tree_config = {constants.TREE_IMPLEMENTATION: "tree_trav"}
    self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, backend="torch.jit", extra_config=tree_config)
|
|
|
|
|
|
|
|
# Decision tree perf_tree_trav regressor
|
|
|
|
def test_decision_tree_ts_perf_tree_trav_regressor_converter(self):
    """Run the "perf_tree_trav" decision tree regression check on the "torch.jit" backend."""
    tree_config = {constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
    self._run_tree_regressor_converter(DecisionTreeRegressor, 1000, backend="torch.jit", extra_config=tree_config)
|
|
|
|
|
|
|
|
# Decision tree classifier
|
|
|
|
def test_decision_tree_ts_classifier_converter(self):
    """Run the 3-class decision tree classification check on the "torch.jit" backend."""
    self._run_tree_classification_converter(DecisionTreeClassifier, 3, backend="torch.jit")
|
|
|
|
|
|
|
|
# Extra trees classifier
|
|
|
|
def test_extra_trees_ts_classifier_converter(self):
    """Run the 3-class extra trees classification check on the "torch.jit" backend."""
    self._run_tree_classification_converter(
        ExtraTreesClassifier,
        3,
        backend="torch.jit",
        n_estimators=10,
    )
|
|
|
|
|
2020-11-04 00:21:02 +03:00
|
|
|
# Test trees with TVM backend
|
|
|
|
# Random forest gemm classifier
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_random_forest_tvm_gemm_classifier_converter(self):
    """Run the 2-class "gemm" random forest classification check on the "tvm" backend."""
    tvm_config = {constants.TREE_IMPLEMENTATION: "gemm", constants.TVM_MAX_FUSE_DEPTH: 30}
    self._run_tree_classification_converter(
        RandomForestClassifier, 2, backend="tvm", extra_config=tvm_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest tree_trav classifier
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_random_forest_tvm_tree_trav_classifier_converter(self):
    """Run the 2-class "tree_trav" random forest classification check on the "tvm" backend."""
    tvm_config = {constants.TREE_IMPLEMENTATION: "tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30}
    self._run_tree_classification_converter(
        RandomForestClassifier, 2, backend="tvm", extra_config=tvm_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest perf_tree_trav classifier
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_random_forest_tvm_perf_tree_trav_classifier_converter(self):
    """Run the 2-class "perf_tree_trav" random forest classification check on the "tvm" backend."""
    tvm_config = {constants.TREE_IMPLEMENTATION: "perf_tree_trav", constants.TVM_MAX_FUSE_DEPTH: 30}
    self._run_tree_classification_converter(
        RandomForestClassifier, 2, backend="tvm", extra_config=tvm_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest gemm multi classifier
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_random_forest_tvm_gemm_multi_classifier_converter(self):
    """Run the 3-class "gemm" random forest classification check on the "tvm" backend."""
    tvm_config = {constants.TREE_IMPLEMENTATION: "gemm", constants.TVM_MAX_FUSE_DEPTH: 30}
    self._run_tree_classification_converter(
        RandomForestClassifier, 3, backend="tvm", extra_config=tvm_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest tree_trav multi classifier
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_random_forest_tvm_tree_trav_multi_classifier_converter(self):
    # Multi-class random forest converted to the "tvm" backend with the
    # tree_trav implementation. The positional 3 is presumably the number
    # of classes -- confirm against the shared helper.
    tvm_config = {
        constants.TREE_IMPLEMENTATION: "tree_trav",
        constants.TVM_MAX_FUSE_DEPTH: 30,
    }
    self._run_tree_classification_converter(
        RandomForestClassifier, 3, "tvm", extra_config=tvm_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest perf_tree_trav multi classifier
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_random_forest_tvm_perf_tree_trav_multi_classifier_converter(self):
    # Multi-class random forest converted to the "tvm" backend with the
    # perf_tree_trav implementation. The positional 3 is presumably the
    # number of classes -- confirm against the shared helper.
    tvm_config = {
        constants.TREE_IMPLEMENTATION: "perf_tree_trav",
        constants.TVM_MAX_FUSE_DEPTH: 30,
    }
    self._run_tree_classification_converter(
        RandomForestClassifier, 3, "tvm", extra_config=tvm_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest gemm classifier shifted classes
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_random_forest_tvm_gemm_classifier_shifted_labels_converter(self):
    # Random forest classifier on the "tvm" backend with the gemm
    # implementation, passing labels_shift=2 to the shared helper
    # (how the shift is applied is defined there, not here).
    tvm_config = {
        constants.TREE_IMPLEMENTATION: "gemm",
        constants.TVM_MAX_FUSE_DEPTH: 30,
    }
    self._run_tree_classification_converter(
        RandomForestClassifier,
        3,
        "tvm",
        labels_shift=2,
        extra_config=tvm_config,
        n_estimators=10,
    )
|
|
|
|
|
|
|
|
# Random forest tree_trav classifier shifted classes
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_random_forest_tvm_tree_trav_classifier_shifted_labels_converter(self):
    # Random forest classifier on the "tvm" backend with the tree_trav
    # implementation, passing labels_shift=2 to the shared helper
    # (how the shift is applied is defined there, not here).
    tvm_config = {
        constants.TREE_IMPLEMENTATION: "tree_trav",
        constants.TVM_MAX_FUSE_DEPTH: 30,
    }
    self._run_tree_classification_converter(
        RandomForestClassifier,
        3,
        "tvm",
        labels_shift=2,
        extra_config=tvm_config,
        n_estimators=10,
    )
|
|
|
|
|
|
|
|
# Random forest perf_tree_trav classifier shifted classes
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_random_forest_tvm_perf_tree_trav_classifier_shifted_labels_converter(self):
    # Random forest classifier on the "tvm" backend with the perf_tree_trav
    # implementation, passing labels_shift=2 to the shared helper.
    # Note the fuse depth here is 10, not the 30 used by the gemm and
    # tree_trav variants of this test.
    tvm_config = {
        constants.TREE_IMPLEMENTATION: "perf_tree_trav",
        constants.TVM_MAX_FUSE_DEPTH: 10,
    }
    self._run_tree_classification_converter(
        RandomForestClassifier,
        3,
        "tvm",
        labels_shift=2,
        extra_config=tvm_config,
        n_estimators=10,
    )
|
|
|
|
|
|
|
|
# Random forest gemm regressor
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_random_forest_tvm_gemm_regressor_converter(self):
    # Random forest regressor converted to the "tvm" backend with the gemm
    # implementation. The positional 1000 is forwarded as-is; its meaning
    # is defined by _run_tree_regressor_converter.
    tvm_config = {
        constants.TREE_IMPLEMENTATION: "gemm",
        constants.TVM_MAX_FUSE_DEPTH: 30,
    }
    self._run_tree_regressor_converter(
        RandomForestRegressor, 1000, "tvm", extra_config=tvm_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest tree_trav regressor
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_random_forest_tvm_tree_trav_regressor_converter(self):
    # Random forest regressor converted to the "tvm" backend with the
    # tree_trav implementation. The positional 1000 is forwarded as-is;
    # its meaning is defined by _run_tree_regressor_converter.
    tvm_config = {
        constants.TREE_IMPLEMENTATION: "tree_trav",
        constants.TVM_MAX_FUSE_DEPTH: 30,
    }
    self._run_tree_regressor_converter(
        RandomForestRegressor, 1000, "tvm", extra_config=tvm_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Random forest perf_tree_trav regressor
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_random_forest_tvm_perf_tree_trav_regressor_converter(self):
    # Random forest regressor converted to the "tvm" backend with the
    # perf_tree_trav implementation; fuse depth is 10 here rather than
    # the 30 used by the gemm and tree_trav regressor tests.
    tvm_config = {
        constants.TREE_IMPLEMENTATION: "perf_tree_trav",
        constants.TVM_MAX_FUSE_DEPTH: 10,
    }
    self._run_tree_regressor_converter(
        RandomForestRegressor, 1000, "tvm", extra_config=tvm_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Extra trees gemm regressor
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_extra_trees_tvm_gemm_regressor_converter(self):
    # Extra-trees regressor converted to the "tvm" backend with the gemm
    # implementation. The positional 1000 is forwarded as-is; its meaning
    # is defined by _run_tree_regressor_converter.
    tvm_config = {
        constants.TREE_IMPLEMENTATION: "gemm",
        constants.TVM_MAX_FUSE_DEPTH: 30,
    }
    self._run_tree_regressor_converter(
        ExtraTreesRegressor, 1000, "tvm", extra_config=tvm_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Extra trees tree_trav regressor
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_extra_trees_tvm_tree_trav_regressor_converter(self):
    # Extra-trees regressor converted to the "tvm" backend with the
    # tree_trav implementation. The positional 1000 is forwarded as-is;
    # its meaning is defined by _run_tree_regressor_converter.
    tvm_config = {
        constants.TREE_IMPLEMENTATION: "tree_trav",
        constants.TVM_MAX_FUSE_DEPTH: 30,
    }
    self._run_tree_regressor_converter(
        ExtraTreesRegressor, 1000, "tvm", extra_config=tvm_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Extra trees perf_tree_trav regressor
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_extra_trees_tvm_perf_tree_trav_regressor_converter(self):
    # Extra-trees regressor converted to the "tvm" backend with the
    # perf_tree_trav implementation; fuse depth is 10 here rather than
    # the 30 used by the gemm and tree_trav extra-trees tests.
    tvm_config = {
        constants.TREE_IMPLEMENTATION: "perf_tree_trav",
        constants.TVM_MAX_FUSE_DEPTH: 10,
    }
    self._run_tree_regressor_converter(
        ExtraTreesRegressor, 1000, "tvm", extra_config=tvm_config, n_estimators=10
    )
|
|
|
|
|
|
|
|
# Decision tree regressor
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_decision_tree_tvm_regressor_converter(self):
    # Decision tree regressor converted to the "tvm" backend. No tree
    # implementation is requested here, so the choice is left to the
    # converter; only the fuse depth is configured.
    tvm_config = {constants.TVM_MAX_FUSE_DEPTH: 30}
    self._run_tree_regressor_converter(
        DecisionTreeRegressor, 1000, "tvm", extra_config=tvm_config
    )
|
|
|
|
|
|
|
|
# Decision tree gemm regressor
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_decision_tree_tvm_gemm_regressor_converter(self):
    # Decision tree regressor converted to the "tvm" backend with the
    # gemm implementation.
    tvm_config = {
        constants.TREE_IMPLEMENTATION: "gemm",
        constants.TVM_MAX_FUSE_DEPTH: 30,
    }
    self._run_tree_regressor_converter(
        DecisionTreeRegressor, 1000, "tvm", extra_config=tvm_config
    )
|
|
|
|
|
|
|
|
# Decision tree tree_trav regressor
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_decision_tree_tvm_tree_trav_regressor_converter(self):
    # Decision tree regressor converted to the "tvm" backend with the
    # tree_trav implementation.
    tvm_config = {
        constants.TREE_IMPLEMENTATION: "tree_trav",
        constants.TVM_MAX_FUSE_DEPTH: 30,
    }
    self._run_tree_regressor_converter(
        DecisionTreeRegressor, 1000, "tvm", extra_config=tvm_config
    )
|
|
|
|
|
|
|
|
# Decision tree perf_tree_trav regressor
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_decision_tree_tvm_perf_tree_trav_regressor_converter(self):
    # Decision tree regressor converted to the "tvm" backend with the
    # perf_tree_trav implementation; fuse depth is 10 here rather than
    # the 30 used by the other decision tree regressor tests.
    tvm_config = {
        constants.TREE_IMPLEMENTATION: "perf_tree_trav",
        constants.TVM_MAX_FUSE_DEPTH: 10,
    }
    self._run_tree_regressor_converter(
        DecisionTreeRegressor, 1000, "tvm", extra_config=tvm_config
    )
|
|
|
|
|
|
|
|
# Decision tree classifier
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_decision_tree_tvm_classifier_converter(self):
    # Decision tree classifier converted to the "tvm" backend. No tree
    # implementation is requested; only the fuse depth is configured.
    # The positional 3 is presumably the number of classes -- confirm
    # against the shared helper.
    tvm_config = {constants.TVM_MAX_FUSE_DEPTH: 30}
    self._run_tree_classification_converter(
        DecisionTreeClassifier, 3, "tvm", extra_config=tvm_config
    )
|
|
|
|
|
|
|
|
# Extra trees classifier
|
|
|
|
@unittest.skipIf(not tvm_installed(), reason="TVM tests require TVM")
@unittest.skipIf(
    sys.platform == "linux" and is_on_github_actions(),
    reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_extra_trees_tvm_classifier_converter(self):
    # Extra-trees classifier converted to the "tvm" backend. No tree
    # implementation is requested; only the fuse depth is configured.
    # The positional 3 is presumably the number of classes -- confirm
    # against the shared helper.
    tvm_config = {constants.TVM_MAX_FUSE_DEPTH: 30}
    self._run_tree_classification_converter(
        ExtraTreesClassifier, 3, "tvm", n_estimators=10, extra_config=tvm_config
    )
|
|
|
|
|
2021-01-14 01:27:09 +03:00
|
|
|
# TreeRegressor multioutput regression
|
|
|
|
def test_tree_regressors_multioutput_regression(self):
    # Checks that the converted torch model reproduces sklearn predictions
    # for every combination of tree implementation, number of regression
    # targets and tree-based regressor class.
    for tree_method in ["gemm", "tree_trav", "perf_tree_trav"]:
        for n_targets in [1, 2, 7]:
            for tree_class in [DecisionTreeRegressor, ExtraTreesRegressor, RandomForestRegressor]:
                # A fresh random seed per combination; it is echoed in the
                # assertion message so a failing case can be reproduced.
                seed = random.randint(0, 2**32 - 1)
                if tree_method == "perf_tree_trav":
                    # Only the perf_tree_trav case caps the tree depth.
                    model = tree_class(random_state=seed, max_depth=10)
                else:
                    model = tree_class(random_state=seed)

                X, y = datasets.make_regression(
                    n_samples=100, n_features=10, n_informative=5, n_targets=n_targets, random_state=seed
                )
                model.fit(X, y)
                # Both models are evaluated on the float32 view of the inputs.
                X = X.astype('float32')

                torch_model = hummingbird.ml.convert(
                    model,
                    "torch",
                    extra_config={constants.TREE_IMPLEMENTATION: tree_method, constants.TREE_OP_PRECISION_DTYPE: "float64"}
                )
                self.assertIsNotNone(torch_model)

                expected = model.predict(X)
                actual = torch_model.predict(X)
                failure_context = "{}/{}/{}/{}".format(tree_method, n_targets, tree_class, seed)
                np.testing.assert_allclose(expected, actual, rtol=1e-5, atol=1e-5, err_msg=failure_context)
|
2021-01-14 01:27:09 +03:00
|
|
|
|
2020-04-29 19:09:41 +03:00
|
|
|
|
|
|
|
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|