Fix bug with degenerate trees (#426)
* solve error with degenerate trees better error handling in case of exceptions during conversion
This commit is contained in:
Родитель
1af5f14834
Коммит
2c5b19f34e
|
@ -149,26 +149,22 @@ def convert(topology, backend, test_input, device, extra_config={}):
|
|||
tvm_backend = tvm.__name__
|
||||
|
||||
for operator in topology.topological_operator_iterator():
|
||||
try:
|
||||
converter = get_converter(operator.type)
|
||||
|
||||
if backend == onnx.__name__:
|
||||
# vers = LooseVersion(torch.__version__)
|
||||
# allowed_min = LooseVersion("1.6.0")
|
||||
# Pytorch <= 1.6.0 has a bug with exporting GEMM into ONNX.
|
||||
# For the moment only tree_trav is enabled for pytorch <= 1.6.0
|
||||
# if vers < allowed_min:
|
||||
extra_config[constants.TREE_IMPLEMENTATION] = "tree_trav"
|
||||
|
||||
operator_map[operator.full_name] = converter(operator, device, extra_config)
|
||||
except ValueError:
|
||||
converter = get_converter(operator.type)
|
||||
if convert is None:
|
||||
raise MissingConverter(
|
||||
"Unable to find converter for {} type {} with extra config: {}.".format(
|
||||
operator.type, type(getattr(operator, "raw_model", None)), extra_config
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
if backend == onnx.__name__:
|
||||
# vers = LooseVersion(torch.__version__)
|
||||
# allowed_min = LooseVersion("1.6.0")
|
||||
# Pytorch <= 1.6.0 has a bug with exporting GEMM into ONNX.
|
||||
# For the moment only tree_trav is enabled for pytorch <= 1.6.0
|
||||
# if vers < allowed_min:
|
||||
extra_config[constants.TREE_IMPLEMENTATION] = "tree_trav"
|
||||
operator_map[operator.full_name] = converter(operator, device, extra_config)
|
||||
|
||||
# Set the parameters for the model / container
|
||||
n_threads = None if constants.N_THREADS not in extra_config else extra_config[constants.N_THREADS]
|
||||
|
|
|
@ -14,6 +14,7 @@ import numpy as np
|
|||
from ._tree_implementations import TreeImpl
|
||||
from ._tree_implementations import GEMMDecisionTreeImpl, TreeTraversalDecisionTreeImpl, PerfectTreeTraversalDecisionTreeImpl
|
||||
from . import constants
|
||||
from hummingbird.ml.exceptions import MissingConverter
|
||||
|
||||
|
||||
class Node:
|
||||
|
@ -144,7 +145,7 @@ def get_tree_implementation_by_config_or_depth(extra_config, max_depth, low=3, h
|
|||
elif extra_config[constants.TREE_IMPLEMENTATION] == TreeImpl.perf_tree_trav.name:
|
||||
return TreeImpl.perf_tree_trav
|
||||
else:
|
||||
raise ValueError("Tree implementation {} not found".format(extra_config))
|
||||
raise MissingConverter("Tree implementation {} not found".format(extra_config))
|
||||
|
||||
|
||||
def get_tree_params_and_type(tree_infos, get_tree_parameters, extra_config):
|
||||
|
@ -205,7 +206,7 @@ def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, val
|
|||
features = [0, 0, 0]
|
||||
thresholds = [0, 0, 0]
|
||||
n_classes = values.shape[1] if type(values) is np.ndarray else 1
|
||||
values = np.array([np.array([0.0]), values[0], values[0]])
|
||||
values = np.array([np.zeros(n_classes), values[0], values[0]])
|
||||
values.reshape(3, n_classes)
|
||||
|
||||
ids = [i for i in range(len(lefts))]
|
||||
|
|
|
@ -238,6 +238,19 @@ class TestSklearnTreeConverter(unittest.TestCase):
|
|||
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
|
||||
)
|
||||
|
||||
# Another small tree tests
|
||||
def test_random_forest_classifier_small_tree_converter(self):
|
||||
seed = 0
|
||||
np.random.seed(seed=0)
|
||||
N = 9
|
||||
X = np.random.randn(N, 8)
|
||||
y = np.random.randint(low=0, high=2, size=N)
|
||||
model = RandomForestClassifier(random_state=seed)
|
||||
model.fit(X, y)
|
||||
torch_model = hummingbird.ml.convert(model, "torch")
|
||||
self.assertIsNotNone(torch_model)
|
||||
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
|
||||
|
||||
# Float 64 classification test helper
|
||||
def _run_float64_tree_classification_converter(self, model_type, num_classes, extra_config={}, labels_shift=0, **kwargs):
|
||||
warnings.filterwarnings("ignore")
|
||||
|
|
Загрузка…
Ссылка в новой задаче