# File-viewer metadata left over from the page this file was copied from: 131 lines, 6.4 KiB, Python.
"""
Tests for the scikit-learn IsolationForest converter.
"""
|
|
import unittest
|
|
import warnings
|
|
|
|
import numpy as np
|
|
import sys
|
|
import torch
|
|
from sklearn.ensemble import IsolationForest
|
|
|
|
import hummingbird.ml
|
|
from hummingbird.ml import constants
|
|
from hummingbird.ml._utils import onnx_runtime_installed, tvm_installed, is_on_github_actions
|
|
from tree_utils import iforest_implementation_map
|
|
|
|
|
|
class TestIsolationForestConverter(unittest.TestCase):
    """Check Hummingbird's IsolationForest conversion against scikit-learn.

    The backend tests fit one forest per ``max_samples`` value, convert it with
    ``hummingbird.ml.convert`` and assert that the converted model reproduces
    scikit-learn's ``decision_function`` and ``score_samples`` (within ``_RTOL`` /
    ``_ATOL``) and ``predict`` (exactly).
    """

    # ``max_samples`` values exercised by every backend test. The training matrix has only
    # 100 rows, so 2**8, 2**10 and 2**12 all exceed it; scikit-learn then uses every sample
    # for each tree and emits a warning, one reason warnings are silenced during these tests.
    _MAX_SAMPLES = (2 ** 1, 2 ** 3, 2 ** 8, 2 ** 10, 2 ** 12)

    # Tolerances for the continuous outputs (decision_function / score_samples).
    _RTOL = 1e-06
    _ATOL = 1e-06

    # Check tree implementation
    def test_iforest_implementation(self):
        """Each ``tree_implementation`` option must select the operator class listed in ``iforest_implementation_map``."""
        np.random.seed(0)
        X = np.random.rand(10, 1).astype(np.float32)
        model = IsolationForest(n_estimators=1, max_samples=2)

        # Scope the warning suppression to this test instead of leaking a global filter.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Only the operator type is checked, so one fitted model serves every implementation.
            model.fit(X)
            for tree_implementation in ["tree_trav", "perf_tree_trav", "gemm"]:
                with self.subTest(tree_implementation=tree_implementation):
                    torch_model = hummingbird.ml.convert(
                        model, "torch", extra_config={"tree_implementation": tree_implementation}
                    )
                    self.assertIsNotNone(torch_model)
                    self.assertEqual(
                        str(type(list(torch_model.model._operators)[0])),
                        iforest_implementation_map[tree_implementation],
                    )

    def _check_conversion(self, backend, extra_config=None, dtype=np.float32, pass_test_input=False):
        """Fit one IsolationForest per ``_MAX_SAMPLES`` value, convert it and compare it to scikit-learn.

        Args:
            backend: Hummingbird backend name forwarded to ``hummingbird.ml.convert``.
            extra_config: Optional dict forwarded as ``extra_config``; ``None`` means ``{}``.
                A fresh copy is passed to every conversion so no dict object is shared
                between conversions or with the caller.
            dtype: dtype the 100x200 uniform random training matrix is cast to.
            pass_test_input: When true, the training matrix is also passed to ``convert`` as
                its test input (the torch.jit, ONNX and TVM tests do this).
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for max_samples in self._MAX_SAMPLES:
                with self.subTest(backend=backend, max_samples=max_samples):
                    # The forest is built without random_state, so seeding NumPy's global RNG
                    # is what makes both X and the fitted trees reproducible across runs.
                    np.random.seed(0)
                    X = np.random.rand(100, 200).astype(dtype)
                    model = IsolationForest(n_estimators=10, max_samples=max_samples)
                    model.fit(X)

                    config = {} if extra_config is None else dict(extra_config)
                    convert_args = (X,) if pass_test_input else ()
                    hb_model = hummingbird.ml.convert(model, backend, *convert_args, extra_config=config)

                    self.assertIsNotNone(hb_model)
                    np.testing.assert_allclose(
                        model.decision_function(X), hb_model.decision_function(X), rtol=self._RTOL, atol=self._ATOL
                    )
                    np.testing.assert_allclose(
                        model.score_samples(X), hb_model.score_samples(X), rtol=self._RTOL, atol=self._ATOL
                    )
                    np.testing.assert_array_equal(model.predict(X), hb_model.predict(X))

    def _run_isolation_forest_converter(self, extra_config=None):
        """Run the float32 comparison on the plain ``torch`` backend with the given ``extra_config``."""
        self._check_conversion("torch", extra_config=extra_config)

    # Isolation Forest
    def test_isolation_forest_converter(self):
        """Default tree implementation on the torch backend."""
        self._run_isolation_forest_converter()

    # Gemm Isolation Forest
    def test_isolation_forest_gemm_converter(self):
        """GEMM tree implementation on the torch backend."""
        self._run_isolation_forest_converter(extra_config={"tree_implementation": "gemm"})

    # Tree_trav Isolation Forest
    def test_isolation_forest_tree_trav_converter(self):
        """Tree-traversal implementation on the torch backend."""
        self._run_isolation_forest_converter(extra_config={"tree_implementation": "tree_trav"})

    # Perf_tree_trav Isolation Forest
    def test_isolation_forest_perf_tree_trav_converter(self):
        """Perfect-tree-traversal implementation on the torch backend."""
        self._run_isolation_forest_converter(extra_config={"tree_implementation": "perf_tree_trav"})

    # Float 64 data tests
    def test_float64_isolation_forest_converter(self):
        """float64 training/scoring data on the torch backend."""
        self._check_conversion("torch", dtype=np.float64)

    # Test TorchScript backend.
    def test_isolation_forest_ts_converter(self):
        """TorchScript (``torch.jit``) backend, given the data as test input."""
        self._check_conversion("torch.jit", pass_test_input=True)

    # Test ONNX backend.
    @unittest.skipIf(not (onnx_runtime_installed()), reason="ONNX tests require ORT")
    def test_isolation_forest_onnx_converter(self):
        """ONNX backend, given the data as test input."""
        self._check_conversion("onnx", pass_test_input=True)

    # Test TVM backend.
    @unittest.skipIf(not (tvm_installed()), reason="TVM test requires TVM")
    @unittest.skipIf(
        ((sys.platform == "linux") and is_on_github_actions()),
        reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
    )
    def test_isolation_forest_tvm_converter(self):
        """TVM backend, given the data as test input and ``TVM_MAX_FUSE_DEPTH`` set to 30."""
        self._check_conversion("tvm", extra_config={constants.TVM_MAX_FUSE_DEPTH: 30}, pass_test_input=True)


if __name__ == "__main__":
    # Executing this file directly runs every test case defined in it.
    unittest.main(module="__main__")