* solve error with degenerate trees
better error handling in case of exceptions during conversion
This commit is contained in:
Matteo Interlandi 2021-01-07 17:17:08 -08:00 коммит произвёл GitHub
Родитель 1af5f14834
Коммит 2c5b19f34e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 27 добавлений и 17 удалений

Просмотреть файл

@ -149,26 +149,22 @@ def convert(topology, backend, test_input, device, extra_config={}):
tvm_backend = tvm.__name__
for operator in topology.topological_operator_iterator():
try:
converter = get_converter(operator.type)
if backend == onnx.__name__:
# vers = LooseVersion(torch.__version__)
# allowed_min = LooseVersion("1.6.0")
# Pytorch <= 1.6.0 has a bug with exporting GEMM into ONNX.
# For the moment only tree_trav is enabled for pytorch <= 1.6.0
# if vers < allowed_min:
extra_config[constants.TREE_IMPLEMENTATION] = "tree_trav"
operator_map[operator.full_name] = converter(operator, device, extra_config)
except ValueError:
converter = get_converter(operator.type)
if convert is None:
raise MissingConverter(
"Unable to find converter for {} type {} with extra config: {}.".format(
operator.type, type(getattr(operator, "raw_model", None)), extra_config
)
)
except Exception as e:
raise e
if backend == onnx.__name__:
# vers = LooseVersion(torch.__version__)
# allowed_min = LooseVersion("1.6.0")
# Pytorch <= 1.6.0 has a bug with exporting GEMM into ONNX.
# For the moment only tree_trav is enabled for pytorch <= 1.6.0
# if vers < allowed_min:
extra_config[constants.TREE_IMPLEMENTATION] = "tree_trav"
operator_map[operator.full_name] = converter(operator, device, extra_config)
# Set the parameters for the model / container
n_threads = None if constants.N_THREADS not in extra_config else extra_config[constants.N_THREADS]

Просмотреть файл

@ -14,6 +14,7 @@ import numpy as np
from ._tree_implementations import TreeImpl
from ._tree_implementations import GEMMDecisionTreeImpl, TreeTraversalDecisionTreeImpl, PerfectTreeTraversalDecisionTreeImpl
from . import constants
from hummingbird.ml.exceptions import MissingConverter
class Node:
@ -144,7 +145,7 @@ def get_tree_implementation_by_config_or_depth(extra_config, max_depth, low=3, h
elif extra_config[constants.TREE_IMPLEMENTATION] == TreeImpl.perf_tree_trav.name:
return TreeImpl.perf_tree_trav
else:
raise ValueError("Tree implementation {} not found".format(extra_config))
raise MissingConverter("Tree implementation {} not found".format(extra_config))
def get_tree_params_and_type(tree_infos, get_tree_parameters, extra_config):
@ -205,7 +206,7 @@ def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, val
features = [0, 0, 0]
thresholds = [0, 0, 0]
n_classes = values.shape[1] if type(values) is np.ndarray else 1
values = np.array([np.array([0.0]), values[0], values[0]])
values = np.array([np.zeros(n_classes), values[0], values[0]])
values.reshape(3, n_classes)
ids = [i for i in range(len(lefts))]

Просмотреть файл

@ -238,6 +238,19 @@ class TestSklearnTreeConverter(unittest.TestCase):
extra_config={constants.TREE_IMPLEMENTATION: "perf_tree_trav"}
)
# Another small tree tests
def test_random_forest_classifier_small_tree_converter(self):
seed = 0
np.random.seed(seed=0)
N = 9
X = np.random.randn(N, 8)
y = np.random.randint(low=0, high=2, size=N)
model = RandomForestClassifier(random_state=seed)
model.fit(X, y)
torch_model = hummingbird.ml.convert(model, "torch")
self.assertIsNotNone(torch_model)
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
# Float 64 classification test helper
def _run_float64_tree_classification_converter(self, model_type, num_classes, extra_config={}, labels_shift=0, **kwargs):
warnings.filterwarnings("ignore")