From c525c51060a58cb848a4f866b2d844963ada1f94 Mon Sep 17 00:00:00 2001 From: Matteo Interlandi Date: Tue, 2 Feb 2021 17:11:00 -0800 Subject: [PATCH] Revert #419 (#438) --- .../ml/operator_converters/_gbdt_commons.py | 18 +- .../ml/operator_converters/_tree_commons.py | 70 ++---- .../_tree_implementations.py | 207 +++++++----------- .../ml/operator_converters/lightgbm.py | 28 +-- .../operator_converters/onnx/tree_ensemble.py | 13 +- .../sklearn/decision_tree.py | 8 +- .../ml/operator_converters/sklearn/gbdt.py | 10 +- .../ml/operator_converters/sklearn/iforest.py | 4 +- hummingbird/ml/operator_converters/xgb.py | 18 +- tests/test_lightgbm_converter.py | 38 ---- tests/test_xgboost_converter.py | 38 ---- 11 files changed, 128 insertions(+), 324 deletions(-) diff --git a/hummingbird/ml/operator_converters/_gbdt_commons.py b/hummingbird/ml/operator_converters/_gbdt_commons.py index facc8973..406abdd1 100644 --- a/hummingbird/ml/operator_converters/_gbdt_commons.py +++ b/hummingbird/ml/operator_converters/_gbdt_commons.py @@ -15,7 +15,9 @@ from ._tree_commons import get_tree_params_and_type, get_parameters_for_tree_tra from ._tree_implementations import GEMMGBDTImpl, TreeTraversalGBDTImpl, PerfectTreeTraversalGBDTImpl, TreeImpl -def convert_gbdt_classifier_common(operator, tree_infos, get_tree_parameters, n_features, n_classes, classes=None, missing_val=None, extra_config={}): +def convert_gbdt_classifier_common( + operator, tree_infos, get_tree_parameters, n_features, n_classes, classes=None, extra_config={} +): """ Common converter for GBDT classifiers. @@ -25,7 +27,6 @@ def convert_gbdt_classifier_common(operator, tree_infos, get_tree_parameters, n_ n_features: The number of features input to the model n_classes: How many classes are expected. 1 for regression tasks classes: The classes used for classification. None if implementing a regression model - missing_val: The value to be treated as the missing value extra_config: Extra configuration used to properly implement the source tree Returns: @@ -47,10 +48,10 @@ def convert_gbdt_classifier_common(operator, tree_infos, get_tree_parameters, n_ if reorder_trees and n_classes > 1: tree_infos = [tree_infos[i * n_classes + j] for j in range(n_classes) for i in range(len(tree_infos) // n_classes)] - return convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes, missing_val, extra_config) + return convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes, extra_config) -def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes=None, missing_val=None, extra_config={}): +def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes=None, extra_config={}): """ Common converter for GBDT models. @@ -59,7 +60,6 @@ def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, c get_tree_parameters: A function specifying how to parse the tree_infos into parameters n_features: The number of features input to the model classes: The classes used for classification. 
None if implementing a regression model - missing_val: The value to be treated as the missing value extra_config: Extra configuration used to properly implement the source tree Returns: @@ -86,7 +86,6 @@ def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, c tree_param.thresholds, tree_param.values, n_features, - tree_param.missings, extra_config, ) for tree_param in tree_parameters @@ -104,7 +103,6 @@ def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, c tree_param.features, tree_param.thresholds, tree_param.values, - tree_param.missings, extra_config, ) for tree_param in tree_parameters @@ -163,8 +161,8 @@ def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, c # Generate the tree implementation based on the selected strategy. if tree_type == TreeImpl.gemm: - return GEMMGBDTImpl(operator, net_parameters, n_features, classes, missing_val, extra_config) + return GEMMGBDTImpl(operator, net_parameters, n_features, classes, extra_config) if tree_type == TreeImpl.tree_trav: - return TreeTraversalGBDTImpl(operator, net_parameters, max_depth, n_features, classes, missing_val, extra_config) + return TreeTraversalGBDTImpl(operator, net_parameters, max_depth, n_features, classes, extra_config) else: # Remaining possible case: tree_type == TreeImpl.perf_tree_trav. - return PerfectTreeTraversalGBDTImpl(operator, net_parameters, max_depth, n_features, classes, missing_val, extra_config) + return PerfectTreeTraversalGBDTImpl(operator, net_parameters, max_depth, n_features, classes, extra_config) diff --git a/hummingbird/ml/operator_converters/_tree_commons.py b/hummingbird/ml/operator_converters/_tree_commons.py index cbd9a642..0863901c 100644 --- a/hummingbird/ml/operator_converters/_tree_commons.py +++ b/hummingbird/ml/operator_converters/_tree_commons.py @@ -31,7 +31,6 @@ class Node: feature: The feature used to make a decision (if not leaf node, ignored otherwise) threshold: The threshold used in the decision (if not leaf node, ignored otherwise) value: The value stored in the leaf (ignored if not leaf node). - missing: In the case of a missingle value chosen child node id """ self.id = id self.left = None @@ -39,7 +38,6 @@ class Node: self.feature = None self.threshold = None self.value = None - self.missing = None class TreeParameters: @@ -47,7 +45,7 @@ class TreeParameters: Class containing a convenient in-memory representation of a decision tree. 
""" - def __init__(self, lefts, rights, features, thresholds, values, missings=None): + def __init__(self, lefts, rights, features, thresholds, values): """ Args: lefts: The id of the left nodes @@ -55,14 +53,12 @@ class TreeParameters: feature: The features used to make decisions thresholds: The thresholds used in the decisions values: The value stored in the leaves - missings: In the case of a missing value which child node to select """ self.lefts = lefts self.rights = rights self.features = features self.thresholds = thresholds self.values = values - self.missings = missings def _find_max_depth(tree_parameters): @@ -188,7 +184,7 @@ def get_parameters_for_sklearn_common(tree_infos): return TreeParameters(lefts, rights, features, thresholds, values) -def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, values, missings=None, extra_config={}): +def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, values, extra_config={}): """ Common functions used by all tree algorithms to generate the parameters according to the tree_trav strategies. @@ -198,7 +194,7 @@ def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, val features: The features used in the decision nodes thresholds: The thresholds used in the decision nodes values: The values stored in the leaf nodes - missings: In the case of a missing value which child node to select + Returns: An array containing the extracted parameters """ @@ -209,26 +205,18 @@ def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, val rights = [2, -1, -1] features = [0, 0, 0] thresholds = [0, 0, 0] - if missings is not None: - missings = [2, -1, -1] n_classes = values.shape[1] if type(values) is np.ndarray else 1 values = np.array([np.zeros(n_classes), values[0], values[0]]) values.reshape(3, n_classes) ids = [i for i in range(len(lefts))] - if missings is not None: - nodes = list(zip(ids, lefts, rights, features, thresholds, values, missings)) - else: - nodes = list(zip(ids, lefts, rights, features, thresholds, values)) + nodes = list(zip(ids, lefts, rights, features, thresholds, values)) # Refactor the tree parameters in the proper format. 
nodes_map = {0: Node(0)} current_node = 0 for i, node in enumerate(nodes): - if missings is not None: - id, left, right, feature, threshold, value, missing = node - else: - id, left, right, feature, threshold, value = node + id, left, right, feature, threshold, value = node if left != -1: l_node = Node(left) @@ -252,13 +240,6 @@ def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, val nodes_map[current_node].threshold = threshold nodes_map[current_node].value = value - if missings is not None: - m_node = l_node if missing == left else r_node - nodes_map[current_node].missing = m_node - - if missings[i] == -1: - missings[i] = id - current_node += 1 lefts = np.array(lefts) @@ -266,13 +247,11 @@ def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, val features = np.array(features) thresholds = np.array(thresholds) values = np.array(values) - if missings is not None: - missings = np.array(missings) - return [nodes_map, ids, lefts, rights, features, thresholds, values, missings] + return [nodes_map, ids, lefts, rights, features, thresholds, values] -def get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, missings=None, classes=None, extra_config={}): +def get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, classes=None, extra_config={}): """ This function is used to generate tree parameters for sklearn trees. Includes SklearnRandomForestClassifier/Regressor, and SklearnGradientBoostingClassifier. @@ -283,7 +262,6 @@ def get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, va features: The features used in the decision nodes thresholds: The thresholds used in the decision nodes values: The values stored in the leaf nodes - missings: In the case of a missing value which child node to select classes: The list of class labels. None if regression model Returns: An array containing the extracted parameters @@ -298,10 +276,10 @@ def get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, va if constants.NUM_TREES in extra_config: values /= extra_config[constants.NUM_TREES] - return get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, values, missings) + return get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, values) -def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values, n_features, missings=None, extra_config={}): +def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values, n_features, extra_config={}): """ Common functions used by all tree algorithms to generate the parameters according to the GEMM strategy. @@ -312,7 +290,7 @@ def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values, thresholds: The thresholds used in the decision nodes values: The values stored in the leaf nodes n_features: The number of expected input features - missings: In the case of a missing value which child node to select + Returns: The weights and bias for the GEMM implementation """ @@ -327,35 +305,20 @@ def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values, rights = [2, -1, -1] features = [0, 0, 0] thresholds = [0, 0, 0] - if missings is not None: - missings = [2, -1, -1] n_classes = values.shape[1] values = np.array([np.zeros(n_classes), values[0], values[0]]) values.reshape(3, n_classes) - if missings is None: - missings = rights - # First hidden layer has all inequalities. 
hidden_weights = [] hidden_biases = [] - hidden_missing_biases = [] - for left, right, missing, feature, thresh in zip(lefts, rights, missings, features, thresholds): - if left != -1 or right != -1: + for left, feature, thresh in zip(lefts, features, thresholds): + if left != -1: hidden_weights.append([1 if i == feature else 0 for i in range(n_features)]) hidden_biases.append(thresh) - - if missing == right: - hidden_missing_biases.append(1) - else: - hidden_missing_biases.append(0) weights.append(np.array(hidden_weights).astype("float32")) biases.append(np.array(hidden_biases).astype("float32")) - # Missing value handling biases. - weights.append(None) - biases.append(np.array(hidden_missing_biases).astype("float32")) - n_splits = len(hidden_weights) # Second hidden layer has ANDs for each leaf of the decision tree. @@ -381,10 +344,10 @@ def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values, for j, p in enumerate(path[:-1]): num_leaves_before_p = list(lefts[:p]).count(-1) if path[j + 1] in lefts: - vec[p - num_leaves_before_p] = -1 - elif path[j + 1] in rights: - num_positive += 1 vec[p - num_leaves_before_p] = 1 + num_positive += 1 + elif path[j + 1] in rights: + vec[p - num_leaves_before_p] = -1 else: raise RuntimeError("Inconsistent state encountered while tree translation.") @@ -433,7 +396,6 @@ def convert_decision_ensemble_tree_common( tree_param.thresholds, tree_param.values, n_features, - tree_param.missings, extra_config, ) for tree_param in tree_parameters @@ -442,7 +404,7 @@ def convert_decision_ensemble_tree_common( net_parameters = [ get_parameters_for_tree_trav( - tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, tree_param.missings, extra_config, + tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, extra_config, ) for tree_param in tree_parameters ] diff --git a/hummingbird/ml/operator_converters/_tree_implementations.py b/hummingbird/ml/operator_converters/_tree_implementations.py index 564a140c..572befc5 100644 --- a/hummingbird/ml/operator_converters/_tree_implementations.py +++ b/hummingbird/ml/operator_converters/_tree_implementations.py @@ -53,14 +53,13 @@ class AbstractPyTorchTreeImpl(AbstracTreeImpl, torch.nn.Module): Abstract class definig the basic structure for tree-base models implemented in PyTorch. """ - def __init__(self, logical_operator, tree_parameters, n_features, classes, n_classes, missing_val, **kwargs): + def __init__(self, logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs): """ Args: tree_parameters: The parameters defining the tree structure n_features: The number of features input to the model classes: The classes used for classification. None if implementing a regression model n_classes: The total number of used classes - missing_val: The value to be treated as the missing value """ super(AbstractPyTorchTreeImpl, self).__init__(logical_operator, **kwargs) @@ -84,32 +83,23 @@ class AbstractPyTorchTreeImpl(AbstracTreeImpl, torch.nn.Module): self.classes = torch.nn.Parameter(torch.IntTensor(classes), requires_grad=False) self.perform_class_select = True - self.missing_val = missing_val - if self.missing_val in [None, np.nan]: - self.missing_val_op = torch.isnan - else: - def missing_val_op(x): - return x == self.missing_val - self.missing_val_op = missing_val_op - class GEMMTreeImpl(AbstractPyTorchTreeImpl): """ Class implementing the GEMM strategy in PyTorch for tree-base models. 
""" - def __init__(self, logical_operator, tree_parameters, n_features, classes, n_classes=None, missing_val=None, extra_config={}, **kwargs): + def __init__(self, logical_operator, tree_parameters, n_features, classes, n_classes=None, extra_config={}, **kwargs): """ Args: tree_parameters: The parameters defining the tree structure n_features: The number of features input to the model classes: The classes used for classification. None if implementing a regression model n_classes: The total number of used classes - missing_val: The value to be treated as the missing value """ # If n_classes is not provided we induce it from tree parameters. Multioutput regression targets are also treated as separate classes. - n_classes = n_classes if n_classes is not None else tree_parameters[0][0][3].shape[0] - super(GEMMTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, n_classes, missing_val, **kwargs) + n_classes = n_classes if n_classes is not None else tree_parameters[0][0][2].shape[0] + super(GEMMTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs) # Initialize the actual model. hidden_one_size = 0 @@ -118,24 +108,22 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl): for weight, bias in tree_parameters: hidden_one_size = max(hidden_one_size, weight[0].shape[0]) - hidden_two_size = max(hidden_two_size, weight[2].shape[0]) + hidden_two_size = max(hidden_two_size, weight[1].shape[0]) n_trees = len(tree_parameters) - weight_1 = np.zeros((n_trees, hidden_one_size)) + weight_1 = np.zeros((n_trees, hidden_one_size, n_features)) bias_1 = np.zeros((n_trees, hidden_one_size)) - missing_bias_1 = np.zeros((n_trees, hidden_one_size)) weight_2 = np.zeros((n_trees, hidden_two_size, hidden_one_size)) bias_2 = np.zeros((n_trees, hidden_two_size)) weight_3 = np.zeros((n_trees, hidden_three_size, hidden_two_size)) for i, (weight, bias) in enumerate(tree_parameters): if len(weight[0]) > 0: - weight_1[i, 0 : weight[0].shape[0]] = np.argmax(weight[0], axis=1) + weight_1[i, 0 : weight[0].shape[0], 0 : weight[0].shape[1]] = weight[0] bias_1[i, 0 : bias[0].shape[0]] = bias[0] - missing_bias_1[i, 0 : bias[1].shape[0]] = bias[1] - weight_2[i, 0 : weight[2].shape[0], 0 : weight[2].shape[1]] = weight[2] - bias_2[i, 0 : bias[2].shape[0]] = bias[2] - weight_3[i, 0 : weight[3].shape[0], 0 : weight[3].shape[1]] = weight[3] + weight_2[i, 0 : weight[1].shape[0], 0 : weight[1].shape[1]] = weight[1] + bias_2[i, 0 : bias[1].shape[0]] = bias[1] + weight_3[i, 0 : weight[2].shape[0], 0 : weight[2].shape[1]] = weight[2] self.n_trees = n_trees self.n_features = n_features @@ -143,19 +131,13 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl): self.hidden_two_size = hidden_two_size self.hidden_three_size = hidden_three_size - self.weight_1 = torch.nn.Parameter(torch.from_numpy(weight_1.reshape(-1).astype("int64")), requires_grad=False) - self.bias_1 = torch.nn.Parameter(torch.from_numpy(bias_1.reshape(1, -1).astype("float32")), requires_grad=False) + self.weight_1 = torch.nn.Parameter(torch.from_numpy(weight_1.reshape(-1, self.n_features).astype("float32"))) + self.bias_1 = torch.nn.Parameter(torch.from_numpy(bias_1.reshape(-1, 1).astype("float32"))) - # By default when we compare nan to any value the output will be false. 
Thus we need to explicitly - # account for missing values only when missings are different to lefts (i.e., False condition) - if np.sum(missing_bias_1) != 0: - self.missing_bias_1 = torch.nn.Parameter(torch.from_numpy(missing_bias_1.reshape(1, -1).astype("float32")), requires_grad=False) - else: - self.missing_bias_1 = None + self.weight_2 = torch.nn.Parameter(torch.from_numpy(weight_2.astype("float32"))) + self.bias_2 = torch.nn.Parameter(torch.from_numpy(bias_2.reshape(-1, 1).astype("float32"))) - self.weight_2 = torch.nn.Parameter(torch.from_numpy(weight_2.astype("float32")), requires_grad=False) - self.bias_2 = torch.nn.Parameter(torch.from_numpy(bias_2.reshape(-1, 1).astype("float32")), requires_grad=False) - self.weight_3 = torch.nn.Parameter(torch.from_numpy(weight_3.astype("float32")), requires_grad=False) + self.weight_3 = torch.nn.Parameter(torch.from_numpy(weight_3.astype("float32"))) # We register also base_prediction here so that tensor will be moved to the proper hardware with the model. # i.e., if cuda is selected, the parameter will be automatically moved on the GPU. @@ -166,12 +148,10 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl): return x def forward(self, x): - features = torch.index_select(x, 1, self.weight_1) - if self.missing_bias_1 is not None: - x = torch.where(self.missing_val_op(features), self.missing_bias_1 + torch.zeros_like(features), (features >= self.bias_1).float()) - else: - x = (features >= self.bias_1).float() - x = x.view(-1, self.n_trees * self.hidden_one_size).t().view(self.n_trees, self.hidden_one_size, -1) + x = x.t() + x = torch.mm(self.weight_1, x) < self.bias_1 + x = x.view(self.n_trees, self.hidden_one_size, -1) + x = x.float() x = torch.matmul(self.weight_2, x) @@ -207,7 +187,9 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl): indexes = indexes.expand(batch_size, self.num_trees) return indexes.reshape(-1) - def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes, n_classes=None, missing_val=None, extra_config={}, **kwargs): + def __init__( + self, logical_operator, tree_parameters, max_depth, n_features, classes, n_classes=None, extra_config={}, **kwargs + ): """ Args: tree_parameters: The parameters defining the tree structure @@ -215,12 +197,13 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl): n_features: The number of features input to the model classes: The classes used for classification. None if implementing a regression model n_classes: The total number of used classes - missing_val: The value to be treated as the missing value extra_config: Extra configuration used to properly implement the source tree """ # If n_classes is not provided we induce it from tree parameters. Multioutput regression targets are also treated as separate classes. n_classes = n_classes if n_classes is not None else tree_parameters[0][6].shape[1] - super(TreeTraversalTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, n_classes, missing_val, **kwargs) + super(TreeTraversalTreeImpl, self).__init__( + logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs + ) # Initialize the actual model. 
self.n_features = n_features @@ -235,28 +218,16 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl): thresholds = np.zeros((self.num_trees, self.num_nodes), dtype=np.float32) values = np.zeros((self.num_trees, self.num_nodes, self.n_classes), dtype=np.float32) - missings = None - if len(tree_parameters[0]) == 8 and tree_parameters[0][7] is not None: - missings = np.zeros((self.num_trees, self.num_nodes), dtype=np.int64) - for i in range(self.num_trees): lefts[i][: len(tree_parameters[i][0])] = tree_parameters[i][2] rights[i][: len(tree_parameters[i][0])] = tree_parameters[i][3] features[i][: len(tree_parameters[i][0])] = tree_parameters[i][4] thresholds[i][: len(tree_parameters[i][0])] = tree_parameters[i][5] values[i][: len(tree_parameters[i][0])][:] = tree_parameters[i][6] - if missings is not None: - missings[i][: len(tree_parameters[i][0])] = tree_parameters[i][7] self.lefts = torch.nn.Parameter(torch.from_numpy(lefts).view(-1), requires_grad=False) self.rights = torch.nn.Parameter(torch.from_numpy(rights).view(-1), requires_grad=False) - # By default when we compare nan to any value the output will be false. Thus we need to explicitly - # account for missing values when missings are different to lefts (i.e., false condition) - self.missings = None - if missings is not None and not np.allclose(lefts, missings): - self.missings = torch.nn.Parameter(torch.from_numpy(missings).view(-1), requires_grad=False) - self.features = torch.nn.Parameter(torch.from_numpy(features).view(-1), requires_grad=False) self.thresholds = torch.nn.Parameter(torch.from_numpy(thresholds).view(-1)) self.values = torch.nn.Parameter(torch.from_numpy(values).view(-1, self.n_classes)) @@ -284,12 +255,7 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl): lefts = torch.index_select(self.lefts, 0, indexes).view(-1, self.num_trees) rights = torch.index_select(self.rights, 0, indexes).view(-1, self.num_trees) - if self.missings is not None: - missings = torch.index_select(self.missings, 0, indexes).view(-1, self.num_trees) - indexes = torch.where(torch.ge(feature_values, thresholds), rights, lefts).long() - if self.missings is not None: - indexes = torch.where(self.missing_val_op(feature_values), missings, indexes) indexes = indexes + self.nodes_offset indexes = indexes.view(-1) @@ -315,19 +281,22 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl): Class implementing the Perfect Tree Traversal strategy in PyTorch for tree-base models. """ - def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes, n_classes=None, missing_val=None, extra_config={}, **kwargs): + def __init__( + self, logical_operator, tree_parameters, max_depth, n_features, classes, n_classes=None, extra_config={}, **kwargs + ): """ Args: tree_parameters: The parameters defining the tree structure max_depth: The maximum tree-depth in the model n_features: The number of features input to the model classes: The classes used for classification. None if implementing a regression model - missing_val: The value to be treated as the missing value n_classes: The total number of used classes """ # If n_classes is not provided we induce it from tree parameters. Multioutput regression targets are also treated as separate classes. 
n_classes = n_classes if n_classes is not None else tree_parameters[0][6].shape[1] - super(PerfectTreeTraversalTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, n_classes, missing_val, **kwargs) + super(PerfectTreeTraversalTreeImpl, self).__init__( + logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs + ) # Initialize the actual model. self.max_tree_depth = max_depth @@ -336,53 +305,40 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl): node_maps = [tp[0] for tp in tree_parameters] - feature_ids = np.zeros((self.num_trees, 2 ** max_depth - 1)) - threshold_vals = np.zeros((self.num_trees, 2 ** max_depth - 1)) - leaf_vals = np.zeros((self.num_trees, 2 ** max_depth, self.n_classes)) - missings = np.zeros((self.num_trees, 2 ** max_depth - 1), dtype=np.int64) + weight_0 = np.zeros((self.num_trees, 2 ** max_depth - 1)) + bias_0 = np.zeros((self.num_trees, 2 ** max_depth - 1)) + weight_1 = np.zeros((self.num_trees, 2 ** max_depth, self.n_classes)) - # By default when we compare nan to any value the output will be false. Thus, in `_populate_structure_tensors` we check whether there are - # non-trivial missings that are different to lefts (i.e., false condition) and set the `self.has_non_trivial_missing_vals` to True. - self.has_non_trivial_missing_vals = False for i, node_map in enumerate(node_maps): - self._populate_structure_tensors(node_map, max_depth, feature_ids[i], leaf_vals[i], threshold_vals[i], missings[i]) + self._get_weights_and_biases(node_map, max_depth, weight_0[i], weight_1[i], bias_0[i]) node_by_levels = [set() for _ in range(max_depth)] self._traverse_by_level(node_by_levels, 0, -1, max_depth) - self.root_nodes = torch.nn.Parameter(torch.from_numpy(feature_ids[:, 0].flatten().astype("int64")), requires_grad=False) - self.root_biases = torch.nn.Parameter(-1 * torch.from_numpy(threshold_vals[:, 0].astype("float32")), requires_grad=False) - self.root_missing_node_ids = torch.nn.Parameter(torch.from_numpy(missings[:, 0].astype("int64")), requires_grad=False) + self.root_nodes = torch.nn.Parameter(torch.from_numpy(weight_0[:, 0].flatten().astype("int64")), requires_grad=False) + self.root_biases = torch.nn.Parameter(-1 * torch.from_numpy(bias_0[:, 0].astype("float32")), requires_grad=False) tree_indices = np.array([i for i in range(0, 2 * self.num_trees, 2)]).astype("int64") self.tree_indices = torch.nn.Parameter(torch.from_numpy(tree_indices), requires_grad=False) - self.feature_ids = [] - self.threshold_vals = [] - self.missing_node_ids = [] + self.nodes = [] + self.biases = [] for i in range(1, max_depth): - features = torch.nn.Parameter( - torch.from_numpy(feature_ids[:, list(sorted(node_by_levels[i]))].flatten().astype("int64")), requires_grad=False + nodes = torch.nn.Parameter( + torch.from_numpy(weight_0[:, list(sorted(node_by_levels[i]))].flatten().astype("int64")), requires_grad=False ) - thresholds = torch.nn.Parameter( - torch.from_numpy(-1 * threshold_vals[:, list(sorted(node_by_levels[i]))].flatten().astype("float32")), - requires_grad=False, - ) - missing_nodes = torch.nn.Parameter( - torch.from_numpy(missings[:, list(sorted(node_by_levels[i]))].flatten().astype("int64")), + biases = torch.nn.Parameter( + torch.from_numpy(-1 * bias_0[:, list(sorted(node_by_levels[i]))].flatten().astype("float32")), requires_grad=False, ) + self.nodes.append(nodes) + self.biases.append(biases) - self.feature_ids.append(features) - self.threshold_vals.append(thresholds) - self.missing_node_ids.append(missing_nodes) - - 
self.feature_ids = torch.nn.ParameterList(self.feature_ids) - self.threshold_vals = torch.nn.ParameterList(self.threshold_vals) - self.missing_node_ids = torch.nn.ParameterList(self.missing_node_ids) + self.nodes = torch.nn.ParameterList(self.nodes) + self.biases = torch.nn.ParameterList(self.biases) self.leaf_nodes = torch.nn.Parameter( - torch.from_numpy(leaf_vals.reshape((-1, self.n_classes)).astype("float32")), requires_grad=False + torch.from_numpy(weight_1.reshape((-1, self.n_classes)).astype("float32")), requires_grad=False ) # We register also base_prediction here so that tensor will be moved to the proper hardware with the model. @@ -394,23 +350,15 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl): return x def forward(self, x): - root_features = torch.index_select(x, 1, self.root_nodes) - prev_indices = torch.ge(root_features, self.root_biases).long() - if self.has_non_trivial_missing_vals: - prev_indices = torch.where(self.missing_val_op(root_features), self.root_missing_node_ids, prev_indices) + prev_indices = (torch.ge(torch.index_select(x, 1, self.root_nodes), self.root_biases)).long() prev_indices = prev_indices + self.tree_indices prev_indices = prev_indices.view(-1) factor = 2 - for features, thresholds, missings in zip(self.feature_ids, self.threshold_vals, self.missing_node_ids): - gather_indices = torch.index_select(features, 0, prev_indices).view(-1, self.num_trees) + for nodes, biases in zip(self.nodes, self.biases): + gather_indices = torch.index_select(nodes, 0, prev_indices).view(-1, self.num_trees) features = torch.gather(x, 1, gather_indices).view(-1) - thresholds = torch.index_select(thresholds, 0, prev_indices) - node_eval_status = torch.ge(features, thresholds).long() - if self.has_non_trivial_missing_vals: - missings = torch.index_select(missings, 0, prev_indices) - node_eval_status = torch.where(self.missing_val_op(features), missings, node_eval_status) - prev_indices = factor * prev_indices + node_eval_status + prev_indices = factor * prev_indices + torch.ge(features, torch.index_select(biases, 0, prev_indices)).long() output = torch.index_select(self.leaf_nodes, 0, prev_indices).view(-1, self.num_trees, self.n_classes) @@ -438,24 +386,17 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl): node_id = self._traverse_by_level(node_by_levels, node_id, current_level, max_level) return node_id - def _populate_structure_tensors(self, nodes_map, tree_depth, node_ids, threshold_vals, leaf_vals, missings): + def _get_weights_and_biases(self, nodes_map, tree_depth, weight_0, weight_1, bias_0): def depth_f_traversal(node, current_depth, node_id, leaf_start_id): - node_ids[node_id] = node.feature - leaf_vals[node_id] = -node.threshold - - if node.missing is None or node.left == node.missing: - missings[node_id] = 0 - else: - missings[node_id] = 1 - self.has_non_trivial_missing_vals = True - + weight_0[node_id] = node.feature + bias_0[node_id] = -node.threshold current_depth += 1 node_id += 1 if node.left.feature == -1: node_id += 2 ** (tree_depth - current_depth - 1) - 1 v = node.left.value - threshold_vals[leaf_start_id : leaf_start_id + 2 ** (tree_depth - current_depth - 1)] = ( + weight_1[leaf_start_id : leaf_start_id + 2 ** (tree_depth - current_depth - 1)] = ( np.ones((2 ** (tree_depth - current_depth - 1), self.n_classes)) * v ) leaf_start_id += 2 ** (tree_depth - current_depth - 1) @@ -465,7 +406,7 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl): if node.right.feature == -1: node_id += 2 ** (tree_depth - current_depth - 1) 
- 1 v = node.right.value - threshold_vals[leaf_start_id : leaf_start_id + 2 ** (tree_depth - current_depth - 1)] = ( + weight_1[leaf_start_id : leaf_start_id + 2 ** (tree_depth - current_depth - 1)] = ( np.ones((2 ** (tree_depth - current_depth - 1), self.n_classes)) * v ) leaf_start_id += 2 ** (tree_depth - current_depth - 1) @@ -484,15 +425,14 @@ class GEMMDecisionTreeImpl(GEMMTreeImpl): """ - def __init__(self, logical_operator, tree_parameters, n_features, classes=None, missing_val=None): + def __init__(self, logical_operator, tree_parameters, n_features, classes=None): """ Args: tree_parameters: The parameters defining the tree structure n_features: The number of features input to the model classes: The classes used for classification. None if implementing a regression model - missing_val: The value to be treated as the missing value """ - super(GEMMDecisionTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, missing_val) + super(GEMMDecisionTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes) def aggregation(self, x): output = x.sum(0).t() @@ -505,14 +445,13 @@ class TreeTraversalDecisionTreeImpl(TreeTraversalTreeImpl): Class implementing the Tree Traversal strategy in PyTorch for decision tree models. """ - def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, missing_val=None, extra_config={}): + def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}): """ Args: tree_parameters: The parameters defining the tree structure max_depth: The maximum tree-depth in the model n_features: The number of features input to the model classes: The classes used for classification. None if implementing a regression model - missing_val: The value to be treated as the missing value extra_config: Extra configuration used to properly implement the source tree """ super(TreeTraversalDecisionTreeImpl, self).__init__( @@ -530,16 +469,17 @@ class PerfectTreeTraversalDecisionTreeImpl(PerfectTreeTraversalTreeImpl): Class implementing the Perfect Tree Traversal strategy in PyTorch for decision tree models. """ - def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, missing_val=None): + def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None): """ Args: tree_parameters: The parameters defining the tree structure max_depth: The maximum tree-depth in the model n_features: The number of features input to the model classes: The classes used for classification. None if implementing a regression model - missing_val: The value to be treated as the missing value """ - super(PerfectTreeTraversalDecisionTreeImpl, self).__init__(logical_operator, tree_parameters, max_depth, n_features, classes, missing_val) + super(PerfectTreeTraversalDecisionTreeImpl, self).__init__( + logical_operator, tree_parameters, max_depth, n_features, classes + ) def aggregation(self, x): output = x.sum(1) @@ -553,16 +493,15 @@ class GEMMGBDTImpl(GEMMTreeImpl): Class implementing the GEMM strategy (in PyTorch) for GBDT models. """ - def __init__(self, logical_operator, tree_parameters, n_features, classes=None, missing_val=None, extra_config={}): + def __init__(self, logical_operator, tree_parameters, n_features, classes=None, extra_config={}): """ Args: tree_parameters: The parameters defining the tree structure n_features: The number of features input to the model classes: The classes used for classification. 
None if implementing a regression model - missing_val: The value to be treated as the missing value extra_config: Extra configuration used to properly implement the source tree """ - super(GEMMGBDTImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, 1, missing_val, extra_config) + super(GEMMGBDTImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, 1, extra_config) self.n_gbdt_classes = 1 self.post_transform = lambda x: x @@ -586,17 +525,18 @@ class TreeTraversalGBDTImpl(TreeTraversalTreeImpl): Class implementing the Tree Traversal strategy in PyTorch. """ - def __init__(self, logical_operator, tree_parameters, max_detph, n_features, classes=None, missing_val=None, extra_config={}): + def __init__(self, logical_operator, tree_parameters, max_detph, n_features, classes=None, extra_config={}): """ Args: tree_parameters: The parameters defining the tree structure max_depth: The maximum tree-depth in the model n_features: The number of features input to the model classes: The classes used for classification. None if implementing a regression model - missing_val: The value to be treated as the missing_val value extra_config: Extra configuration used to properly implement the source tree """ - super(TreeTraversalGBDTImpl, self).__init__(logical_operator, tree_parameters, max_detph, n_features, classes, 1, missing_val, extra_config) + super(TreeTraversalGBDTImpl, self).__init__( + logical_operator, tree_parameters, max_detph, n_features, classes, 1, extra_config + ) self.n_gbdt_classes = 1 self.post_transform = lambda x: x @@ -620,17 +560,18 @@ class PerfectTreeTraversalGBDTImpl(PerfectTreeTraversalTreeImpl): Class implementing the Perfect Tree Traversal strategy in PyTorch. """ - def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, missing_val=None, extra_config={}): + def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}): """ Args: tree_parameters: The parameters defining the tree structure max_depth: The maximum tree-depth in the model n_features: The number of features input to the model classes: The classes used for classification. None if implementing a regression model - missing_val: The value to be treated as the missing_val value extra_config: Extra configuration used to properly implement the source tree """ - super(PerfectTreeTraversalGBDTImpl, self).__init__(logical_operator, tree_parameters, max_depth, n_features, classes, 1, missing_val, extra_config) + super(PerfectTreeTraversalGBDTImpl, self).__init__( + logical_operator, tree_parameters, max_depth, n_features, classes, 1, extra_config + ) self.n_gbdt_classes = 1 self.post_transform = lambda x: x diff --git a/hummingbird/ml/operator_converters/lightgbm.py b/hummingbird/ml/operator_converters/lightgbm.py index 06e38f09..b3e595df 100644 --- a/hummingbird/ml/operator_converters/lightgbm.py +++ b/hummingbird/ml/operator_converters/lightgbm.py @@ -16,7 +16,7 @@ from ._gbdt_commons import convert_gbdt_classifier_common, convert_gbdt_common from ._tree_commons import TreeParameters -def _tree_traversal(node, lefts, rights, features, thresholds, values, missings, count): +def _tree_traversal(node, lefts, rights, features, thresholds, values, count): """ Recursive function for parsing a tree and filling the input data structures. 
""" @@ -26,23 +26,16 @@ def _tree_traversal(node, lefts, rights, features, thresholds, values, missings, values.append([-1]) lefts.append(count + 1) rights.append(-1) - missings.append(-1) pos = len(rights) - 1 - count = _tree_traversal(node["left_child"], lefts, rights, features, thresholds, values, missings, count + 1) + count = _tree_traversal(node["left_child"], lefts, rights, features, thresholds, values, count + 1) rights[pos] = count + 1 - if node['missing_type'] == 'None': - # Missing values not present in training data are treated as zeros during inference. - missings[pos] = lefts[pos] if 0 < node["threshold"] else rights[pos] - else: - missings[pos] = lefts[pos] if node["default_left"] else rights[pos] - return _tree_traversal(node["right_child"], lefts, rights, features, thresholds, values, missings, count + 1) + return _tree_traversal(node["right_child"], lefts, rights, features, thresholds, values, count + 1) else: features.append(0) thresholds.append(0) values.append([node["leaf_value"]]) lefts.append(-1) rights.append(-1) - missings.append(-1) return count @@ -55,10 +48,9 @@ def _get_tree_parameters(tree_info): features = [] thresholds = [] values = [] - missings = [] - _tree_traversal(tree_info["tree_structure"], lefts, rights, features, thresholds, values, missings, 0) + _tree_traversal(tree_info["tree_structure"], lefts, rights, features, thresholds, values, 0) - return TreeParameters(lefts, rights, features, thresholds, values, missings) + return TreeParameters(lefts, rights, features, thresholds, values) def convert_sklearn_lgbm_classifier(operator, device, extra_config): @@ -74,14 +66,14 @@ def convert_sklearn_lgbm_classifier(operator, device, extra_config): A PyTorch model """ assert operator is not None, "Cannot convert None operator" - assert not hasattr(operator.raw_operator, "use_missing") or operator.raw_operator.use_missing - assert not hasattr(operator.raw_operator, "zero_as_missing") or not operator.raw_operator.zero_as_missing n_features = operator.raw_operator._n_features tree_infos = operator.raw_operator.booster_.dump_model()["tree_info"] n_classes = operator.raw_operator._n_classes - return convert_gbdt_classifier_common(operator, tree_infos, _get_tree_parameters, n_features, n_classes, missing_val=None, extra_config=extra_config) + return convert_gbdt_classifier_common( + operator, tree_infos, _get_tree_parameters, n_features, n_classes, extra_config=extra_config + ) def convert_sklearn_lgbm_regressor(operator, device, extra_config): @@ -97,8 +89,6 @@ def convert_sklearn_lgbm_regressor(operator, device, extra_config): A PyTorch model """ assert operator is not None, "Cannot convert None operator" - assert not hasattr(operator.raw_operator, "use_missing") or operator.raw_operator.use_missing - assert not hasattr(operator.raw_operator, "zero_as_missing") or not operator.raw_operator.zero_as_missing # Get tree information out of the model. n_features = operator.raw_operator._n_features @@ -106,7 +96,7 @@ def convert_sklearn_lgbm_regressor(operator, device, extra_config): if operator.raw_operator._objective == "tweedie": extra_config[constants.POST_TRANSFORM] = constants.TWEEDIE - return convert_gbdt_common(operator, tree_infos, _get_tree_parameters, n_features, missing_val=None, extra_config=extra_config) + return convert_gbdt_common(operator, tree_infos, _get_tree_parameters, n_features, extra_config=extra_config) # Register the converters. 
diff --git a/hummingbird/ml/operator_converters/onnx/tree_ensemble.py b/hummingbird/ml/operator_converters/onnx/tree_ensemble.py index 43efb417..57515c41 100644 --- a/hummingbird/ml/operator_converters/onnx/tree_ensemble.py +++ b/hummingbird/ml/operator_converters/onnx/tree_ensemble.py @@ -214,14 +214,7 @@ def convert_onnx_tree_ensemble_classifier(operator, device=None, extra_config={} ) extra_config[constants.POST_TRANSFORM] = post_transform return convert_gbdt_classifier_common( - operator, - tree_infos, - _dummy_get_parameter, - n_features, - len(classes), - classes, - missing_val=None, - extra_config=extra_config, + operator, tree_infos, _dummy_get_parameter, n_features, len(classes), classes, extra_config ) @@ -243,9 +236,7 @@ def convert_onnx_tree_ensemble_regressor(operator, device=None, extra_config={}) n_features, tree_infos, _, _ = _get_tree_infos_from_tree_ensemble(operator.raw_operator, device, extra_config) # Generate the model. - return convert_gbdt_common( - operator, tree_infos, _dummy_get_parameter, n_features, missing_val=None, extra_config=extra_config - ) + return convert_gbdt_common(operator, tree_infos, _dummy_get_parameter, n_features, extra_config=extra_config) register_converter("ONNXMLTreeEnsembleClassifier", convert_onnx_tree_ensemble_classifier) diff --git a/hummingbird/ml/operator_converters/sklearn/decision_tree.py b/hummingbird/ml/operator_converters/sklearn/decision_tree.py index d37029e2..24cf1440 100644 --- a/hummingbird/ml/operator_converters/sklearn/decision_tree.py +++ b/hummingbird/ml/operator_converters/sklearn/decision_tree.py @@ -42,8 +42,8 @@ def convert_sklearn_random_forest_classifier(operator, device, extra_config): if not all(isinstance(c, int) for c in classes): raise RuntimeError("Random Forest Classifier translation only supports integer class labels") - def get_parameters_for_tree_trav(lefts, rights, features, thresholds, values, missings=None, extra_config={}): - return get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, missings, classes, extra_config) + def get_parameters_for_tree_trav(lefts, rights, features, thresholds, values, extra_config={}): + return get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, classes, extra_config) return convert_decision_ensemble_tree_common( operator, @@ -77,8 +77,8 @@ def convert_sklearn_random_forest_regressor(operator, device, extra_config): # For Sklearn Trees we need to know how many trees are there for normalization. 
extra_config[constants.NUM_TREES] = len(tree_infos) - def get_parameters_for_tree_trav(lefts, rights, features, thresholds, values, missings=None, extra_config={}): - return get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, missings, None, extra_config) + def get_parameters_for_tree_trav(lefts, rights, features, thresholds, values, extra_config={}): + return get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, None, extra_config) return convert_decision_ensemble_tree_common( operator, diff --git a/hummingbird/ml/operator_converters/sklearn/gbdt.py b/hummingbird/ml/operator_converters/sklearn/gbdt.py index 7697b1ca..27d485a1 100644 --- a/hummingbird/ml/operator_converters/sklearn/gbdt.py +++ b/hummingbird/ml/operator_converters/sklearn/gbdt.py @@ -90,7 +90,7 @@ def convert_sklearn_gbdt_classifier(operator, device, extra_config): extra_config[constants.REORDER_TREES] = False return convert_gbdt_classifier_common( - operator, tree_infos, get_parameters_for_sklearn_common, n_features, n_classes, classes, missing_val=None, extra_config=extra_config + operator, tree_infos, get_parameters_for_sklearn_common, n_features, n_classes, classes, extra_config ) @@ -126,7 +126,7 @@ def convert_sklearn_gbdt_regressor(operator, device, extra_config): extra_config[constants.BASE_PREDICTION] = base_prediction - return convert_gbdt_common(operator, tree_infos, get_parameters_for_sklearn_common, n_features, missing_val=None, extra_config=extra_config) + return convert_gbdt_common(operator, tree_infos, get_parameters_for_sklearn_common, n_features, None, extra_config) def convert_sklearn_hist_gbdt_classifier(operator, device, extra_config): @@ -167,7 +167,9 @@ def convert_sklearn_hist_gbdt_classifier(operator, device, extra_config): extra_config[constants.BASE_PREDICTION] = base_prediction extra_config[constants.REORDER_TREES] = False - return convert_gbdt_classifier_common(operator, tree_infos, _get_parameters_hist_gbdt, n_features, n_classes, classes, missing_val=None, extra_config=extra_config) + return convert_gbdt_classifier_common( + operator, tree_infos, _get_parameters_hist_gbdt, n_features, n_classes, classes, extra_config + ) def convert_sklearn_hist_gbdt_regressor(operator, device, extra_config): @@ -190,7 +192,7 @@ def convert_sklearn_hist_gbdt_regressor(operator, device, extra_config): n_features = operator.raw_operator.n_features_ extra_config[constants.BASE_PREDICTION] = [[operator.raw_operator._baseline_prediction]] - return convert_gbdt_common(operator, tree_infos, _get_parameters_hist_gbdt, n_features, missing_val=None, extra_config=extra_config) + return convert_gbdt_common(operator, tree_infos, _get_parameters_hist_gbdt, n_features, None, extra_config) # Register the converters. 
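All of the scikit-learn GBDT entry points above funnel into `convert_gbdt_classifier_common` / `convert_gbdt_common`. A minimal end-to-end usage sketch, mirroring the call pattern used in this repo's tests (the model and data below are illustrative):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier

import hummingbird.ml

X, y = make_classification(
    n_samples=100, n_features=3, n_informative=3, n_redundant=0, random_state=2021
)
model = GradientBoostingClassifier(n_estimators=10).fit(X, y)

# "tree_implementation" selects among the strategies restored by this revert:
# "gemm", "tree_trav", or "perf_tree_trav".
torch_model = hummingbird.ml.convert(
    model, "torch", X, extra_config={"tree_implementation": "gemm"}
)
np.testing.assert_allclose(
    model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06
)
```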
diff --git a/hummingbird/ml/operator_converters/sklearn/iforest.py b/hummingbird/ml/operator_converters/sklearn/iforest.py index b16c97ab..48442b89 100644 --- a/hummingbird/ml/operator_converters/sklearn/iforest.py +++ b/hummingbird/ml/operator_converters/sklearn/iforest.py @@ -239,7 +239,7 @@ def convert_sklearn_isolation_forest(operator, device, extra_config): if tree_type == TreeImpl.gemm: net_parameters = [ get_parameters_for_gemm_common( - tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, n_features, tree_param.missings + tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, n_features ) for tree_param in tree_parameters ] @@ -247,7 +247,7 @@ def convert_sklearn_isolation_forest(operator, device, extra_config): net_parameters = [ get_parameters_for_tree_trav_sklearn( - tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, tree_param.missings, classes + tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, classes ) for tree_param in tree_parameters ] diff --git a/hummingbird/ml/operator_converters/xgb.py b/hummingbird/ml/operator_converters/xgb.py index e6d657fc..ca3d5cb2 100644 --- a/hummingbird/ml/operator_converters/xgb.py +++ b/hummingbird/ml/operator_converters/xgb.py @@ -16,7 +16,7 @@ from ._gbdt_commons import convert_gbdt_classifier_common, convert_gbdt_common from ._tree_commons import TreeParameters -def _tree_traversal(tree_info, lefts, rights, missing, features, thresholds, values): +def _tree_traversal(tree_info, lefts, rights, features, thresholds, values): """ Recursive function for parsing a tree and filling the input data structures. """ @@ -28,7 +28,6 @@ def _tree_traversal(tree_info, lefts, rights, missing, features, thresholds, val values.append([float(tree_info[count].split("=")[1])]) lefts.append(-1) rights.append(-1) - missing.append(-1) count += 1 else: features.append(int(tree_info[count].split(":")[1].split("<")[0].replace("[f", ""))) @@ -57,8 +56,6 @@ def _tree_traversal(tree_info, lefts, rights, missing, features, thresholds, val r_correct_id += 1 rights.append(r_correct_id) - missing_wrong_id = tree_info[count].split(",")[2].replace("missing=", "") - missing.append(l_correct_id if l_wrong_id == missing_wrong_id else r_correct_id) count += 1 @@ -68,15 +65,14 @@ def _get_tree_parameters(tree_info): """ lefts = [] rights = [] - missing = [] features = [] thresholds = [] values = [] _tree_traversal( - tree_info.replace("[f", "").replace("[", "").replace("]", "").split(), lefts, rights, missing, features, thresholds, values + tree_info.replace("[f", "").replace("[", "").replace("]", "").split(), lefts, rights, features, thresholds, values ) - return TreeParameters(lefts, rights, features, thresholds, values, missing) + return TreeParameters(lefts, rights, features, thresholds, values) def convert_sklearn_xgb_classifier(operator, device, extra_config): @@ -101,9 +97,10 @@ def convert_sklearn_xgb_classifier(operator, device, extra_config): ) tree_infos = operator.raw_operator.get_booster().get_dump() n_classes = operator.raw_operator.n_classes_ - missing_val = operator.raw_operator.missing - return convert_gbdt_classifier_common(operator, tree_infos, _get_tree_parameters, n_features, n_classes, missing_val=missing_val, extra_config=extra_config) + return convert_gbdt_classifier_common( + operator, tree_infos, _get_tree_parameters, n_features, n_classes, extra_config=extra_config 
+ ) def convert_sklearn_xgb_regressor(operator, device, extra_config): @@ -135,9 +132,8 @@ def convert_sklearn_xgb_regressor(operator, device, extra_config): base_prediction = [base_prediction] extra_config[constants.BASE_PREDICTION] = base_prediction - missing_val = operator.raw_operator.missing - return convert_gbdt_common(operator, tree_infos, _get_tree_parameters, n_features, missing_val=missing_val, extra_config=extra_config) + return convert_gbdt_common(operator, tree_infos, _get_tree_parameters, n_features, extra_config=extra_config) # Register the converters. diff --git a/tests/test_lightgbm_converter.py b/tests/test_lightgbm_converter.py index 3d25146d..ddd9af5f 100644 --- a/tests/test_lightgbm_converter.py +++ b/tests/test_lightgbm_converter.py @@ -9,7 +9,6 @@ import numpy as np import hummingbird.ml from hummingbird.ml._utils import lightgbm_installed, onnx_runtime_installed, tvm_installed from tree_utils import gbdt_implementation_map -from sklearn.datasets import make_classification, make_regression if lightgbm_installed(): import lightgbm as lgb @@ -259,43 +258,6 @@ class TestLGBMConverter(unittest.TestCase): self.assertIsNotNone(torch_model) np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06) - # Missing values test. - @unittest.skipIf(not lightgbm_installed(), reason="LightGBM test requires LightGBM installed") - def test_run_lgbm_classifier_w_missing_vals_converter(self): - warnings.filterwarnings("ignore") - for extra_config_param in ["gemm", "tree_trav", "perf_tree_trav"]: - for missing in [None, np.nan]: - for model_class, n_classes in zip([lgb.LGBMClassifier, lgb.LGBMClassifier, lgb.LGBMRegressor], [2, 3, None]): - model = model_class(use_missing=True, zero_as_missing=False) - # Missing values during training + inference. - if model_class == lgb.LGBMClassifier: - X, y = make_classification(n_samples=100, n_features=3, n_informative=3, n_redundant=0, n_repeated=0, n_classes=n_classes, random_state=2021) - else: - X, y = make_regression(n_samples=100, n_features=3, n_informative=3, random_state=2021) - X[:25][y[:25] == 0, 0] = np.nan if missing is None else missing - model.fit(X, y) - torch_model = hummingbird.ml.convert(model, "torch", X, extra_config={"tree_implementation": extra_config_param}) - self.assertIsNotNone(torch_model) - if model_class == lgb.LGBMClassifier: - np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06) - else: - np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06) - - # Missing values during only inference. - model = model_class(use_missing=True, zero_as_missing=False) - if model_class == lgb.LGBMClassifier: - X, y = make_classification(n_samples=100, n_features=3, n_informative=3, n_redundant=0, n_repeated=0, n_classes=n_classes, random_state=2021) - else: - X, y = make_regression(n_samples=100, n_features=3, n_informative=3, random_state=2021) - model.fit(X, y) - torch_model = hummingbird.ml.convert(model, "torch", X, extra_config={"tree_implementation": extra_config_param}) - X[:25][y[:25] == 0, 0] = np.nan if missing is None else missing - self.assertIsNotNone(torch_model) - if model_class == lgb.LGBMClassifier: - np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06) - else: - np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06) - # Backend tests. # Test TorchScript backend regression. 
@unittest.skipIf(not lightgbm_installed(), reason="LightGBM test requires LightGBM installed") diff --git a/tests/test_xgboost_converter.py b/tests/test_xgboost_converter.py index db8d14f6..4c4844fb 100644 --- a/tests/test_xgboost_converter.py +++ b/tests/test_xgboost_converter.py @@ -10,7 +10,6 @@ import hummingbird.ml from hummingbird.ml._utils import xgboost_installed, tvm_installed from hummingbird.ml import constants from tree_utils import gbdt_implementation_map -from sklearn.datasets import make_classification, make_regression if xgboost_installed(): import xgboost as xgb @@ -225,43 +224,6 @@ class TestXGBoostConverter(unittest.TestCase): self.assertIsNotNone(torch_model) np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06) - # Missing values test. - @unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires XGBoost installed") - def test_run_xgb_classifier_w_missing_vals_converter(self): - warnings.filterwarnings("ignore") - for extra_config_param in ["gemm", "tree_trav", "perf_tree_trav"]: - for missing in [None, -99999, np.nan]: - for model_class, n_classes in zip([xgb.XGBClassifier, xgb.XGBClassifier, xgb.XGBRegressor], [2, 3, None]): - model = model_class(missing=missing) - # Missing values during both training and inference. - if model_class == xgb.XGBClassifier: - X, y = make_classification(n_samples=100, n_features=3, n_informative=3, n_redundant=0, n_repeated=0, n_classes=n_classes, random_state=2021) - else: - X, y = make_regression(n_samples=100, n_features=3, n_informative=3, random_state=2021) - X[:25][y[:25] == 0, 0] = np.nan if missing is None else missing - model.fit(X, y) - torch_model = hummingbird.ml.convert(model, "torch", X, extra_config={"tree_implementation": extra_config_param}) - self.assertIsNotNone(torch_model) - if model_class == xgb.XGBClassifier: - np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06) - else: - np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06) - - # Missing values during only inference. - model = model_class(missing=missing) - if model_class == xgb.XGBClassifier: - X, y = make_classification(n_samples=100, n_features=3, n_informative=3, n_redundant=0, n_repeated=0, n_classes=n_classes, random_state=2021) - else: - X, y = make_regression(n_samples=100, n_features=3, n_informative=3, random_state=2021) - model.fit(X, y) - X[:25][y[:25] == 0, 0] = np.nan if missing is None else missing - torch_model = hummingbird.ml.convert(model, "torch", X, extra_config={"tree_implementation": extra_config_param}) - self.assertIsNotNone(torch_model) - if model_class == xgb.XGBClassifier: - np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06) - else: - np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06) - # Torchscript backends. # Test TorchScript backend regression. @unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires XGBoost installed")
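A note on semantics after this revert: the deleted tests exercised explicit missing-value routing, and without it the translated trees fall back to plain comparisons, where `nan` compared against any threshold is false, so a NaN feature always follows the "condition false" branch. A NumPy sketch of the restored GEMM formulation on a toy two-split tree makes this concrete (the matrices are hand-built for illustration, not produced by the converter; the second layer checks that every condition on a leaf's root-to-leaf path holds by comparing against the count of left turns, matching `num_positive` above):

```python
import numpy as np

# Toy tree: split 0 tests f0 < 0.5; split 1 tests f1 < 0.3.
W1 = np.array([[1.0, 0.0],    # split 0 reads feature 0
               [0.0, 1.0]])   # split 1 reads feature 1
b1 = np.array([0.5, 0.3])     # thresholds

# One row per leaf: +1 where the path takes a left branch, -1 for right,
# 0 for splits not on the path.
W2 = np.array([[ 1.0,  0.0],   # leaf A: f0 < 0.5
               [-1.0,  1.0],   # leaf B: f0 >= 0.5, f1 < 0.3
               [-1.0, -1.0]])  # leaf C: f0 >= 0.5, f1 >= 0.3
b2 = np.array([1.0, 1.0, 0.0])  # number of +1 entries in each row

W3 = np.array([[10.0, 20.0, 30.0]])  # leaf values

def gemm_tree_eval(x):
    s = (W1 @ x < b1).astype(np.float32)      # which split conditions hold
    leaf = (W2 @ s == b2).astype(np.float32)  # exactly one leaf matches
    return (W3 @ leaf)[0]

print(gemm_tree_eval(np.array([0.2, 0.9])))     # 10.0 -> leaf A
print(gemm_tree_eval(np.array([0.8, 0.1])))     # 20.0 -> leaf B
print(gemm_tree_eval(np.array([np.nan, 0.1])))  # 20.0: nan < 0.5 is False,
                                                # so NaN always goes right
```

With the missing-value machinery removed, source models that were trained to route missing values to the left child (LightGBM's `default_left`, XGBoost's `missing=` branch) can therefore disagree with the PyTorch translation on NaN inputs, which is what the deleted tests previously guarded against.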