This commit is contained in:
Matteo Interlandi 2021-02-02 17:11:00 -08:00 committed by GitHub
Parent b88b6f7e02
Commit c525c51060
No key found corresponding to this signature
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 128 additions and 324 deletions

View file

@ -15,7 +15,9 @@ from ._tree_commons import get_tree_params_and_type, get_parameters_for_tree_tra
from ._tree_implementations import GEMMGBDTImpl, TreeTraversalGBDTImpl, PerfectTreeTraversalGBDTImpl, TreeImpl
def convert_gbdt_classifier_common(operator, tree_infos, get_tree_parameters, n_features, n_classes, classes=None, missing_val=None, extra_config={}):
def convert_gbdt_classifier_common(
operator, tree_infos, get_tree_parameters, n_features, n_classes, classes=None, extra_config={}
):
"""
Common converter for GBDT classifiers.
@ -25,7 +27,6 @@ def convert_gbdt_classifier_common(operator, tree_infos, get_tree_parameters, n_
n_features: The number of features input to the model
n_classes: How many classes are expected. 1 for regression tasks
classes: The classes used for classification. None if implementing a regression model
missing_val: The value to be treated as the missing value
extra_config: Extra configuration used to properly implement the source tree
Returns:
@ -47,10 +48,10 @@ def convert_gbdt_classifier_common(operator, tree_infos, get_tree_parameters, n_
if reorder_trees and n_classes > 1:
tree_infos = [tree_infos[i * n_classes + j] for j in range(n_classes) for i in range(len(tree_infos) // n_classes)]
return convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes, missing_val, extra_config)
return convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes, extra_config)
def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes=None, missing_val=None, extra_config={}):
def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes=None, extra_config={}):
"""
Common converter for GBDT models.
@ -59,7 +60,6 @@ def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, c
get_tree_parameters: A function specifying how to parse the tree_infos into parameters
n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model
missing_val: The value to be treated as the missing value
extra_config: Extra configuration used to properly implement the source tree
Returns:
@ -86,7 +86,6 @@ def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, c
tree_param.thresholds,
tree_param.values,
n_features,
tree_param.missings,
extra_config,
)
for tree_param in tree_parameters
@ -104,7 +103,6 @@ def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, c
tree_param.features,
tree_param.thresholds,
tree_param.values,
tree_param.missings,
extra_config,
)
for tree_param in tree_parameters
@ -163,8 +161,8 @@ def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, c
# Generate the tree implementation based on the selected strategy.
if tree_type == TreeImpl.gemm:
return GEMMGBDTImpl(operator, net_parameters, n_features, classes, missing_val, extra_config)
return GEMMGBDTImpl(operator, net_parameters, n_features, classes, extra_config)
if tree_type == TreeImpl.tree_trav:
return TreeTraversalGBDTImpl(operator, net_parameters, max_depth, n_features, classes, missing_val, extra_config)
return TreeTraversalGBDTImpl(operator, net_parameters, max_depth, n_features, classes, extra_config)
else: # Remaining possible case: tree_type == TreeImpl.perf_tree_trav.
return PerfectTreeTraversalGBDTImpl(operator, net_parameters, max_depth, n_features, classes, missing_val, extra_config)
return PerfectTreeTraversalGBDTImpl(operator, net_parameters, max_depth, n_features, classes, extra_config)
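The three strategy classes returned above (GEMM, TreeTraversal, PerfectTreeTraversal) are selected at conversion time through Hummingbird's extra_config. A minimal usage sketch, assuming LightGBM is installed (this mirrors the test code removed later in this commit; the toy model and data are illustrative):

import numpy as np
import lightgbm as lgb
import hummingbird.ml

# Train a toy model.
X = np.random.rand(100, 3).astype(np.float32)
y = np.random.randint(2, size=100)
model = lgb.LGBMClassifier(n_estimators=10).fit(X, y)

# "gemm", "tree_trav", or "perf_tree_trav" pick the implementation that
# convert_gbdt_common dispatches on above.
torch_model = hummingbird.ml.convert(
    model, "torch", X, extra_config={"tree_implementation": "gemm"}
)
np.testing.assert_allclose(
    model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06
)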

View file

@ -31,7 +31,6 @@ class Node:
feature: The feature used to make a decision (if not a leaf node; ignored otherwise)
threshold: The threshold used in the decision (if not a leaf node; ignored otherwise)
value: The value stored in the leaf (ignored if not a leaf node).
missing: The child node id to choose in the case of a missing value
"""
self.id = id
self.left = None
@ -39,7 +38,6 @@ class Node:
self.feature = None
self.threshold = None
self.value = None
self.missing = None
class TreeParameters:
@ -47,7 +45,7 @@ class TreeParameters:
Class containing a convenient in-memory representation of a decision tree.
"""
def __init__(self, lefts, rights, features, thresholds, values, missings=None):
def __init__(self, lefts, rights, features, thresholds, values):
"""
Args:
lefts: The id of the left nodes
@ -55,14 +53,12 @@ class TreeParameters:
features: The features used to make decisions
thresholds: The thresholds used in the decisions
values: The value stored in the leaves
missings: The child node to select in the case of a missing value
"""
self.lefts = lefts
self.rights = rights
self.features = features
self.thresholds = thresholds
self.values = values
self.missings = missings
def _find_max_depth(tree_parameters):
@ -188,7 +184,7 @@ def get_parameters_for_sklearn_common(tree_infos):
return TreeParameters(lefts, rights, features, thresholds, values)
def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, values, missings=None, extra_config={}):
def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, values, extra_config={}):
"""
Common function used by all tree algorithms to generate the parameters according to the tree_trav strategies.
@ -198,7 +194,7 @@ def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, val
features: The features used in the decision nodes
thresholds: The thresholds used in the decision nodes
values: The values stored in the leaf nodes
missings: The child node to select in the case of a missing value
Returns:
An array containing the extracted parameters
"""
@ -209,26 +205,18 @@ def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, val
rights = [2, -1, -1]
features = [0, 0, 0]
thresholds = [0, 0, 0]
if missings is not None:
missings = [2, -1, -1]
n_classes = values.shape[1] if type(values) is np.ndarray else 1
values = np.array([np.zeros(n_classes), values[0], values[0]])
values = values.reshape(3, n_classes)
ids = [i for i in range(len(lefts))]
if missings is not None:
nodes = list(zip(ids, lefts, rights, features, thresholds, values, missings))
else:
nodes = list(zip(ids, lefts, rights, features, thresholds, values))
nodes = list(zip(ids, lefts, rights, features, thresholds, values))
# Refactor the tree parameters in the proper format.
nodes_map = {0: Node(0)}
current_node = 0
for i, node in enumerate(nodes):
if missings is not None:
id, left, right, feature, threshold, value, missing = node
else:
id, left, right, feature, threshold, value = node
id, left, right, feature, threshold, value = node
if left != -1:
l_node = Node(left)
@ -252,13 +240,6 @@ def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, val
nodes_map[current_node].threshold = threshold
nodes_map[current_node].value = value
if missings is not None:
m_node = l_node if missing == left else r_node
nodes_map[current_node].missing = m_node
if missings[i] == -1:
missings[i] = id
current_node += 1
lefts = np.array(lefts)
@ -266,13 +247,11 @@ def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, val
features = np.array(features)
thresholds = np.array(thresholds)
values = np.array(values)
if missings is not None:
missings = np.array(missings)
return [nodes_map, ids, lefts, rights, features, thresholds, values, missings]
return [nodes_map, ids, lefts, rights, features, thresholds, values]
def get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, missings=None, classes=None, extra_config={}):
def get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, classes=None, extra_config={}):
"""
This function is used to generate tree parameters for sklearn trees.
Includes SklearnRandomForestClassifier/Regressor, and SklearnGradientBoostingClassifier.
@ -283,7 +262,6 @@ def get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, va
features: The features used in the decision nodes
thresholds: The thresholds used in the decision nodes
values: The values stored in the leaf nodes
missings: The child node to select in the case of a missing value
classes: The list of class labels. None if regression model
Returns:
An array containing the extracted parameters
@ -298,10 +276,10 @@ def get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, va
if constants.NUM_TREES in extra_config:
values /= extra_config[constants.NUM_TREES]
return get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, values, missings)
return get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, values)
def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values, n_features, missings=None, extra_config={}):
def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values, n_features, extra_config={}):
"""
Common function used by all tree algorithms to generate the parameters according to the GEMM strategy.
@ -312,7 +290,7 @@ def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values,
thresholds: The thresholds used in the decision nodes
values: The values stored in the leaf nodes
n_features: The number of expected input features
missings: The child node to select in the case of a missing value
Returns:
The weights and bias for the GEMM implementation
"""
@ -327,35 +305,20 @@ def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values,
rights = [2, -1, -1]
features = [0, 0, 0]
thresholds = [0, 0, 0]
if missings is not None:
missings = [2, -1, -1]
n_classes = values.shape[1]
values = np.array([np.zeros(n_classes), values[0], values[0]])
values = values.reshape(3, n_classes)
if missings is None:
missings = rights
# First hidden layer has all inequalities.
hidden_weights = []
hidden_biases = []
hidden_missing_biases = []
for left, right, missing, feature, thresh in zip(lefts, rights, missings, features, thresholds):
if left != -1 or right != -1:
for left, feature, thresh in zip(lefts, features, thresholds):
if left != -1:
hidden_weights.append([1 if i == feature else 0 for i in range(n_features)])
hidden_biases.append(thresh)
if missing == right:
hidden_missing_biases.append(1)
else:
hidden_missing_biases.append(0)
weights.append(np.array(hidden_weights).astype("float32"))
biases.append(np.array(hidden_biases).astype("float32"))
# Missing value handling biases.
weights.append(None)
biases.append(np.array(hidden_missing_biases).astype("float32"))
n_splits = len(hidden_weights)
# Second hidden layer has ANDs for each leaf of the decision tree.
@ -381,10 +344,10 @@ def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values,
for j, p in enumerate(path[:-1]):
num_leaves_before_p = list(lefts[:p]).count(-1)
if path[j + 1] in lefts:
vec[p - num_leaves_before_p] = -1
elif path[j + 1] in rights:
num_positive += 1
vec[p - num_leaves_before_p] = 1
num_positive += 1
elif path[j + 1] in rights:
vec[p - num_leaves_before_p] = -1
else:
raise RuntimeError("Inconsistent state encountered while tree translation.")
@ -433,7 +396,6 @@ def convert_decision_ensemble_tree_common(
tree_param.thresholds,
tree_param.values,
n_features,
tree_param.missings,
extra_config,
)
for tree_param in tree_parameters
@ -442,7 +404,7 @@ def convert_decision_ensemble_tree_common(
net_parameters = [
get_parameters_for_tree_trav(
tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, tree_param.missings, extra_config,
tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, extra_config,
)
for tree_param in tree_parameters
]
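As the comments above note, the GEMM strategy encodes each tree as three dense layers: the first holds all split inequalities, the second ANDs the splits along each root-to-leaf path, and the third maps the resulting one-hot leaf indicator to leaf values. A minimal NumPy sketch on a hand-built stump (the weights below are written out by hand for illustration, and the ==-bias leaf activation is an assumption about the elided layer-2 comparison; the real parameters come from get_parameters_for_gemm_common):

import numpy as np

# Hand-built stump over 2 features: if x[0] < 0.5 output 1.0, else 2.0.
# Layer 1 ("all inequalities"): one row per internal node.
W1 = np.array([[1.0, 0.0]])   # one-hot row selecting feature 0
b1 = np.array([0.5])          # split threshold
# Layer 2 ("ANDs for each leaf"): +1 if the path takes the left (true)
# branch at that split, -1 for the right; bias = number of +1 entries.
W2 = np.array([[1.0], [-1.0]])
b2 = np.array([1.0, 0.0])
# Layer 3: per-leaf output values.
W3 = np.array([[1.0, 2.0]])

x = np.array([[0.3, 0.9]])
splits = (x @ W1.T < b1).astype(float)        # which inequalities hold
leaves = (splits @ W2.T == b2).astype(float)  # one-hot leaf indicator
print(leaves @ W3.T)                          # [[1.0]] -- the left leaf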

View file

@ -53,14 +53,13 @@ class AbstractPyTorchTreeImpl(AbstracTreeImpl, torch.nn.Module):
Abstract class defining the basic structure for tree-based models implemented in PyTorch.
"""
def __init__(self, logical_operator, tree_parameters, n_features, classes, n_classes, missing_val, **kwargs):
def __init__(self, logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs):
"""
Args:
tree_parameters: The parameters defining the tree structure
n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model
n_classes: The total number of used classes
missing_val: The value to be treated as the missing value
"""
super(AbstractPyTorchTreeImpl, self).__init__(logical_operator, **kwargs)
@ -84,32 +83,23 @@ class AbstractPyTorchTreeImpl(AbstracTreeImpl, torch.nn.Module):
self.classes = torch.nn.Parameter(torch.IntTensor(classes), requires_grad=False)
self.perform_class_select = True
self.missing_val = missing_val
if self.missing_val in [None, np.nan]:
self.missing_val_op = torch.isnan
else:
def missing_val_op(x):
return x == self.missing_val
self.missing_val_op = missing_val_op
class GEMMTreeImpl(AbstractPyTorchTreeImpl):
"""
Class implementing the GEMM strategy in PyTorch for tree-based models.
"""
def __init__(self, logical_operator, tree_parameters, n_features, classes, n_classes=None, missing_val=None, extra_config={}, **kwargs):
def __init__(self, logical_operator, tree_parameters, n_features, classes, n_classes=None, extra_config={}, **kwargs):
"""
Args:
tree_parameters: The parameters defining the tree structure
n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model
n_classes: The total number of used classes
missing_val: The value to be treated as the missing value
"""
# If n_classes is not provided we infer it from the tree parameters. Multioutput regression targets are also treated as separate classes.
n_classes = n_classes if n_classes is not None else tree_parameters[0][0][3].shape[0]
super(GEMMTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, n_classes, missing_val, **kwargs)
n_classes = n_classes if n_classes is not None else tree_parameters[0][0][2].shape[0]
super(GEMMTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs)
# Initialize the actual model.
hidden_one_size = 0
@ -118,24 +108,22 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl):
for weight, bias in tree_parameters:
hidden_one_size = max(hidden_one_size, weight[0].shape[0])
hidden_two_size = max(hidden_two_size, weight[2].shape[0])
hidden_two_size = max(hidden_two_size, weight[1].shape[0])
n_trees = len(tree_parameters)
weight_1 = np.zeros((n_trees, hidden_one_size))
weight_1 = np.zeros((n_trees, hidden_one_size, n_features))
bias_1 = np.zeros((n_trees, hidden_one_size))
missing_bias_1 = np.zeros((n_trees, hidden_one_size))
weight_2 = np.zeros((n_trees, hidden_two_size, hidden_one_size))
bias_2 = np.zeros((n_trees, hidden_two_size))
weight_3 = np.zeros((n_trees, hidden_three_size, hidden_two_size))
for i, (weight, bias) in enumerate(tree_parameters):
if len(weight[0]) > 0:
weight_1[i, 0 : weight[0].shape[0]] = np.argmax(weight[0], axis=1)
weight_1[i, 0 : weight[0].shape[0], 0 : weight[0].shape[1]] = weight[0]
bias_1[i, 0 : bias[0].shape[0]] = bias[0]
missing_bias_1[i, 0 : bias[1].shape[0]] = bias[1]
weight_2[i, 0 : weight[2].shape[0], 0 : weight[2].shape[1]] = weight[2]
bias_2[i, 0 : bias[2].shape[0]] = bias[2]
weight_3[i, 0 : weight[3].shape[0], 0 : weight[3].shape[1]] = weight[3]
weight_2[i, 0 : weight[1].shape[0], 0 : weight[1].shape[1]] = weight[1]
bias_2[i, 0 : bias[1].shape[0]] = bias[1]
weight_3[i, 0 : weight[2].shape[0], 0 : weight[2].shape[1]] = weight[2]
self.n_trees = n_trees
self.n_features = n_features
@ -143,19 +131,13 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl):
self.hidden_two_size = hidden_two_size
self.hidden_three_size = hidden_three_size
self.weight_1 = torch.nn.Parameter(torch.from_numpy(weight_1.reshape(-1).astype("int64")), requires_grad=False)
self.bias_1 = torch.nn.Parameter(torch.from_numpy(bias_1.reshape(1, -1).astype("float32")), requires_grad=False)
self.weight_1 = torch.nn.Parameter(torch.from_numpy(weight_1.reshape(-1, self.n_features).astype("float32")))
self.bias_1 = torch.nn.Parameter(torch.from_numpy(bias_1.reshape(-1, 1).astype("float32")))
# By default, comparing nan to any value yields false. Thus we need to explicitly
# account for missing values only when missings differ from lefts (i.e., the false condition)
if np.sum(missing_bias_1) != 0:
self.missing_bias_1 = torch.nn.Parameter(torch.from_numpy(missing_bias_1.reshape(1, -1).astype("float32")), requires_grad=False)
else:
self.missing_bias_1 = None
self.weight_2 = torch.nn.Parameter(torch.from_numpy(weight_2.astype("float32")))
self.bias_2 = torch.nn.Parameter(torch.from_numpy(bias_2.reshape(-1, 1).astype("float32")))
self.weight_2 = torch.nn.Parameter(torch.from_numpy(weight_2.astype("float32")), requires_grad=False)
self.bias_2 = torch.nn.Parameter(torch.from_numpy(bias_2.reshape(-1, 1).astype("float32")), requires_grad=False)
self.weight_3 = torch.nn.Parameter(torch.from_numpy(weight_3.astype("float32")), requires_grad=False)
self.weight_3 = torch.nn.Parameter(torch.from_numpy(weight_3.astype("float32")))
# We also register base_prediction here so that the tensor will be moved to the proper hardware with the model,
# i.e., if cuda is selected, the parameter will be automatically moved to the GPU.
@ -166,12 +148,10 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl):
return x
def forward(self, x):
features = torch.index_select(x, 1, self.weight_1)
if self.missing_bias_1 is not None:
x = torch.where(self.missing_val_op(features), self.missing_bias_1 + torch.zeros_like(features), (features >= self.bias_1).float())
else:
x = (features >= self.bias_1).float()
x = x.view(-1, self.n_trees * self.hidden_one_size).t().view(self.n_trees, self.hidden_one_size, -1)
x = x.t()
x = torch.mm(self.weight_1, x) < self.bias_1
x = x.view(self.n_trees, self.hidden_one_size, -1)
x = x.float()
x = torch.matmul(self.weight_2, x)
@ -207,7 +187,9 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
indexes = indexes.expand(batch_size, self.num_trees)
return indexes.reshape(-1)
def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes, n_classes=None, missing_val=None, extra_config={}, **kwargs):
def __init__(
self, logical_operator, tree_parameters, max_depth, n_features, classes, n_classes=None, extra_config={}, **kwargs
):
"""
Args:
tree_parameters: The parameters defining the tree structure
@ -215,12 +197,13 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model
n_classes: The total number of used classes
missing_val: The value to be treated as the missing value
extra_config: Extra configuration used to properly implement the source tree
"""
# If n_classes is not provided we infer it from the tree parameters. Multioutput regression targets are also treated as separate classes.
n_classes = n_classes if n_classes is not None else tree_parameters[0][6].shape[1]
super(TreeTraversalTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, n_classes, missing_val, **kwargs)
super(TreeTraversalTreeImpl, self).__init__(
logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs
)
# Initialize the actual model.
self.n_features = n_features
@ -235,28 +218,16 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
thresholds = np.zeros((self.num_trees, self.num_nodes), dtype=np.float32)
values = np.zeros((self.num_trees, self.num_nodes, self.n_classes), dtype=np.float32)
missings = None
if len(tree_parameters[0]) == 8 and tree_parameters[0][7] is not None:
missings = np.zeros((self.num_trees, self.num_nodes), dtype=np.int64)
for i in range(self.num_trees):
lefts[i][: len(tree_parameters[i][0])] = tree_parameters[i][2]
rights[i][: len(tree_parameters[i][0])] = tree_parameters[i][3]
features[i][: len(tree_parameters[i][0])] = tree_parameters[i][4]
thresholds[i][: len(tree_parameters[i][0])] = tree_parameters[i][5]
values[i][: len(tree_parameters[i][0])][:] = tree_parameters[i][6]
if missings is not None:
missings[i][: len(tree_parameters[i][0])] = tree_parameters[i][7]
self.lefts = torch.nn.Parameter(torch.from_numpy(lefts).view(-1), requires_grad=False)
self.rights = torch.nn.Parameter(torch.from_numpy(rights).view(-1), requires_grad=False)
# By default, comparing nan to any value yields false. Thus we need to explicitly
# account for missing values when missings differ from lefts (i.e., the false condition)
self.missings = None
if missings is not None and not np.allclose(lefts, missings):
self.missings = torch.nn.Parameter(torch.from_numpy(missings).view(-1), requires_grad=False)
self.features = torch.nn.Parameter(torch.from_numpy(features).view(-1), requires_grad=False)
self.thresholds = torch.nn.Parameter(torch.from_numpy(thresholds).view(-1))
self.values = torch.nn.Parameter(torch.from_numpy(values).view(-1, self.n_classes))
@ -284,12 +255,7 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
lefts = torch.index_select(self.lefts, 0, indexes).view(-1, self.num_trees)
rights = torch.index_select(self.rights, 0, indexes).view(-1, self.num_trees)
if self.missings is not None:
missings = torch.index_select(self.missings, 0, indexes).view(-1, self.num_trees)
indexes = torch.where(torch.ge(feature_values, thresholds), rights, lefts).long()
if self.missings is not None:
indexes = torch.where(self.missing_val_op(feature_values), missings, indexes)
indexes = indexes + self.nodes_offset
indexes = indexes.view(-1)
@ -315,19 +281,22 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
Class implementing the Perfect Tree Traversal strategy in PyTorch for tree-based models.
"""
def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes, n_classes=None, missing_val=None, extra_config={}, **kwargs):
def __init__(
self, logical_operator, tree_parameters, max_depth, n_features, classes, n_classes=None, extra_config={}, **kwargs
):
"""
Args:
tree_parameters: The parameters defining the tree structure
max_depth: The maximum tree-depth in the model
n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model
missing_val: The value to be treated as the missing value
n_classes: The total number of used classes
"""
# If n_classes is not provided we infer it from the tree parameters. Multioutput regression targets are also treated as separate classes.
n_classes = n_classes if n_classes is not None else tree_parameters[0][6].shape[1]
super(PerfectTreeTraversalTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, n_classes, missing_val, **kwargs)
super(PerfectTreeTraversalTreeImpl, self).__init__(
logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs
)
# Initialize the actual model.
self.max_tree_depth = max_depth
@ -336,53 +305,40 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
node_maps = [tp[0] for tp in tree_parameters]
feature_ids = np.zeros((self.num_trees, 2 ** max_depth - 1))
threshold_vals = np.zeros((self.num_trees, 2 ** max_depth - 1))
leaf_vals = np.zeros((self.num_trees, 2 ** max_depth, self.n_classes))
missings = np.zeros((self.num_trees, 2 ** max_depth - 1), dtype=np.int64)
weight_0 = np.zeros((self.num_trees, 2 ** max_depth - 1))
bias_0 = np.zeros((self.num_trees, 2 ** max_depth - 1))
weight_1 = np.zeros((self.num_trees, 2 ** max_depth, self.n_classes))
# By default, comparing nan to any value yields false. Thus, in `_populate_structure_tensors` we check whether there are
# non-trivial missings that differ from lefts (i.e., the false condition) and set `self.has_non_trivial_missing_vals` to True.
self.has_non_trivial_missing_vals = False
for i, node_map in enumerate(node_maps):
self._populate_structure_tensors(node_map, max_depth, feature_ids[i], leaf_vals[i], threshold_vals[i], missings[i])
self._get_weights_and_biases(node_map, max_depth, weight_0[i], weight_1[i], bias_0[i])
node_by_levels = [set() for _ in range(max_depth)]
self._traverse_by_level(node_by_levels, 0, -1, max_depth)
self.root_nodes = torch.nn.Parameter(torch.from_numpy(feature_ids[:, 0].flatten().astype("int64")), requires_grad=False)
self.root_biases = torch.nn.Parameter(-1 * torch.from_numpy(threshold_vals[:, 0].astype("float32")), requires_grad=False)
self.root_missing_node_ids = torch.nn.Parameter(torch.from_numpy(missings[:, 0].astype("int64")), requires_grad=False)
self.root_nodes = torch.nn.Parameter(torch.from_numpy(weight_0[:, 0].flatten().astype("int64")), requires_grad=False)
self.root_biases = torch.nn.Parameter(-1 * torch.from_numpy(bias_0[:, 0].astype("float32")), requires_grad=False)
tree_indices = np.array([i for i in range(0, 2 * self.num_trees, 2)]).astype("int64")
self.tree_indices = torch.nn.Parameter(torch.from_numpy(tree_indices), requires_grad=False)
self.feature_ids = []
self.threshold_vals = []
self.missing_node_ids = []
self.nodes = []
self.biases = []
for i in range(1, max_depth):
features = torch.nn.Parameter(
torch.from_numpy(feature_ids[:, list(sorted(node_by_levels[i]))].flatten().astype("int64")), requires_grad=False
nodes = torch.nn.Parameter(
torch.from_numpy(weight_0[:, list(sorted(node_by_levels[i]))].flatten().astype("int64")), requires_grad=False
)
thresholds = torch.nn.Parameter(
torch.from_numpy(-1 * threshold_vals[:, list(sorted(node_by_levels[i]))].flatten().astype("float32")),
requires_grad=False,
)
missing_nodes = torch.nn.Parameter(
torch.from_numpy(missings[:, list(sorted(node_by_levels[i]))].flatten().astype("int64")),
biases = torch.nn.Parameter(
torch.from_numpy(-1 * bias_0[:, list(sorted(node_by_levels[i]))].flatten().astype("float32")),
requires_grad=False,
)
self.nodes.append(nodes)
self.biases.append(biases)
self.feature_ids.append(features)
self.threshold_vals.append(thresholds)
self.missing_node_ids.append(missing_nodes)
self.feature_ids = torch.nn.ParameterList(self.feature_ids)
self.threshold_vals = torch.nn.ParameterList(self.threshold_vals)
self.missing_node_ids = torch.nn.ParameterList(self.missing_node_ids)
self.nodes = torch.nn.ParameterList(self.nodes)
self.biases = torch.nn.ParameterList(self.biases)
self.leaf_nodes = torch.nn.Parameter(
torch.from_numpy(leaf_vals.reshape((-1, self.n_classes)).astype("float32")), requires_grad=False
torch.from_numpy(weight_1.reshape((-1, self.n_classes)).astype("float32")), requires_grad=False
)
# We also register base_prediction here so that the tensor will be moved to the proper hardware with the model.
@ -394,23 +350,15 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
return x
def forward(self, x):
root_features = torch.index_select(x, 1, self.root_nodes)
prev_indices = torch.ge(root_features, self.root_biases).long()
if self.has_non_trivial_missing_vals:
prev_indices = torch.where(self.missing_val_op(root_features), self.root_missing_node_ids, prev_indices)
prev_indices = (torch.ge(torch.index_select(x, 1, self.root_nodes), self.root_biases)).long()
prev_indices = prev_indices + self.tree_indices
prev_indices = prev_indices.view(-1)
factor = 2
for features, thresholds, missings in zip(self.feature_ids, self.threshold_vals, self.missing_node_ids):
gather_indices = torch.index_select(features, 0, prev_indices).view(-1, self.num_trees)
for nodes, biases in zip(self.nodes, self.biases):
gather_indices = torch.index_select(nodes, 0, prev_indices).view(-1, self.num_trees)
features = torch.gather(x, 1, gather_indices).view(-1)
thresholds = torch.index_select(thresholds, 0, prev_indices)
node_eval_status = torch.ge(features, thresholds).long()
if self.has_non_trivial_missing_vals:
missings = torch.index_select(missings, 0, prev_indices)
node_eval_status = torch.where(self.missing_val_op(features), missings, node_eval_status)
prev_indices = factor * prev_indices + node_eval_status
prev_indices = factor * prev_indices + torch.ge(features, torch.index_select(biases, 0, prev_indices)).long()
output = torch.index_select(self.leaf_nodes, 0, prev_indices).view(-1, self.num_trees, self.n_classes)
@ -438,24 +386,17 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
node_id = self._traverse_by_level(node_by_levels, node_id, current_level, max_level)
return node_id
def _populate_structure_tensors(self, nodes_map, tree_depth, node_ids, threshold_vals, leaf_vals, missings):
def _get_weights_and_biases(self, nodes_map, tree_depth, weight_0, weight_1, bias_0):
def depth_f_traversal(node, current_depth, node_id, leaf_start_id):
node_ids[node_id] = node.feature
leaf_vals[node_id] = -node.threshold
if node.missing is None or node.left == node.missing:
missings[node_id] = 0
else:
missings[node_id] = 1
self.has_non_trivial_missing_vals = True
weight_0[node_id] = node.feature
bias_0[node_id] = -node.threshold
current_depth += 1
node_id += 1
if node.left.feature == -1:
node_id += 2 ** (tree_depth - current_depth - 1) - 1
v = node.left.value
threshold_vals[leaf_start_id : leaf_start_id + 2 ** (tree_depth - current_depth - 1)] = (
weight_1[leaf_start_id : leaf_start_id + 2 ** (tree_depth - current_depth - 1)] = (
np.ones((2 ** (tree_depth - current_depth - 1), self.n_classes)) * v
)
leaf_start_id += 2 ** (tree_depth - current_depth - 1)
@ -465,7 +406,7 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
if node.right.feature == -1:
node_id += 2 ** (tree_depth - current_depth - 1) - 1
v = node.right.value
threshold_vals[leaf_start_id : leaf_start_id + 2 ** (tree_depth - current_depth - 1)] = (
weight_1[leaf_start_id : leaf_start_id + 2 ** (tree_depth - current_depth - 1)] = (
np.ones((2 ** (tree_depth - current_depth - 1), self.n_classes)) * v
)
leaf_start_id += 2 ** (tree_depth - current_depth - 1)
@ -484,15 +425,14 @@ class GEMMDecisionTreeImpl(GEMMTreeImpl):
"""
def __init__(self, logical_operator, tree_parameters, n_features, classes=None, missing_val=None):
def __init__(self, logical_operator, tree_parameters, n_features, classes=None):
"""
Args:
tree_parameters: The parameters defining the tree structure
n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model
missing_val: The value to be treated as the missing value
"""
super(GEMMDecisionTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, missing_val)
super(GEMMDecisionTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes)
def aggregation(self, x):
output = x.sum(0).t()
@ -505,14 +445,13 @@ class TreeTraversalDecisionTreeImpl(TreeTraversalTreeImpl):
Class implementing the Tree Traversal strategy in PyTorch for decision tree models.
"""
def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, missing_val=None, extra_config={}):
def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}):
"""
Args:
tree_parameters: The parameters defining the tree structure
max_depth: The maximum tree-depth in the model
n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model
missing_val: The value to be treated as the missing value
extra_config: Extra configuration used to properly implement the source tree
"""
super(TreeTraversalDecisionTreeImpl, self).__init__(
@ -530,16 +469,17 @@ class PerfectTreeTraversalDecisionTreeImpl(PerfectTreeTraversalTreeImpl):
Class implementing the Perfect Tree Traversal strategy in PyTorch for decision tree models.
"""
def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, missing_val=None):
def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None):
"""
Args:
tree_parameters: The parameters defining the tree structure
max_depth: The maximum tree-depth in the model
n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model
missing_val: The value to be treated as the missing value
"""
super(PerfectTreeTraversalDecisionTreeImpl, self).__init__(logical_operator, tree_parameters, max_depth, n_features, classes, missing_val)
super(PerfectTreeTraversalDecisionTreeImpl, self).__init__(
logical_operator, tree_parameters, max_depth, n_features, classes
)
def aggregation(self, x):
output = x.sum(1)
@ -553,16 +493,15 @@ class GEMMGBDTImpl(GEMMTreeImpl):
Class implementing the GEMM strategy (in PyTorch) for GBDT models.
"""
def __init__(self, logical_operator, tree_parameters, n_features, classes=None, missing_val=None, extra_config={}):
def __init__(self, logical_operator, tree_parameters, n_features, classes=None, extra_config={}):
"""
Args:
tree_parameters: The parameters defining the tree structure
n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model
missing_val: The value to be treated as the missing value
extra_config: Extra configuration used to properly implement the source tree
"""
super(GEMMGBDTImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, 1, missing_val, extra_config)
super(GEMMGBDTImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, 1, extra_config)
self.n_gbdt_classes = 1
self.post_transform = lambda x: x
@ -586,17 +525,18 @@ class TreeTraversalGBDTImpl(TreeTraversalTreeImpl):
Class implementing the Tree Traversal strategy in PyTorch.
"""
def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, missing_val=None, extra_config={}):
def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}):
"""
Args:
tree_parameters: The parameters defining the tree structure
max_depth: The maximum tree-depth in the model
n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model
missing_val: The value to be treated as the missing value
extra_config: Extra configuration used to properly implement the source tree
"""
super(TreeTraversalGBDTImpl, self).__init__(logical_operator, tree_parameters, max_depth, n_features, classes, 1, missing_val, extra_config)
super(TreeTraversalGBDTImpl, self).__init__(
logical_operator, tree_parameters, max_depth, n_features, classes, 1, extra_config
)
self.n_gbdt_classes = 1
self.post_transform = lambda x: x
@ -620,17 +560,18 @@ class PerfectTreeTraversalGBDTImpl(PerfectTreeTraversalTreeImpl):
Class implementing the Perfect Tree Traversal strategy in PyTorch.
"""
def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, missing_val=None, extra_config={}):
def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}):
"""
Args:
tree_parameters: The parameters defining the tree structure
max_depth: The maximum tree-depth in the model
n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model
missing_val: The value to be treated as the missing value
extra_config: Extra configuration used to properly implement the source tree
"""
super(PerfectTreeTraversalGBDTImpl, self).__init__(logical_operator, tree_parameters, max_depth, n_features, classes, 1, missing_val, extra_config)
super(PerfectTreeTraversalGBDTImpl, self).__init__(
logical_operator, tree_parameters, max_depth, n_features, classes, 1, extra_config
)
self.n_gbdt_classes = 1
self.post_transform = lambda x: x
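PerfectTreeTraversalTreeImpl.forward above walks all trees level by level using heap-style indexing: each comparison outcome doubles the running node index and adds 0 (left) or 1 (right), so a perfect tree of depth d needs only d vectorized gathers. A scalar sketch of the same arithmetic on one hypothetical hand-built tree (for illustration only):

# Perfect tree of depth 3, stored level by level.
features = [[0], [1, 0], [2, 1, 0, 2]]                 # feature id per node
thresholds = [[0.5], [0.1, 0.9], [0.3, 0.4, 0.5, 0.6]]  # threshold per node
leaf_values = [1., 2., 3., 4., 5., 6., 7., 8.]          # 2**3 leaves

x = [0.7, 0.2, 0.35]
idx = 0
for level in range(3):
    # Mirrors: prev_indices = factor * prev_indices + torch.ge(...).long()
    decision = int(x[features[level][idx]] >= thresholds[level][idx])
    idx = 2 * idx + decision
print(leaf_values[idx])  # -> 6.0 for this input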

View file

@ -16,7 +16,7 @@ from ._gbdt_commons import convert_gbdt_classifier_common, convert_gbdt_common
from ._tree_commons import TreeParameters
def _tree_traversal(node, lefts, rights, features, thresholds, values, missings, count):
def _tree_traversal(node, lefts, rights, features, thresholds, values, count):
"""
Recursive function for parsing a tree and filling the input data structures.
"""
@ -26,23 +26,16 @@ def _tree_traversal(node, lefts, rights, features, thresholds, values, missings,
values.append([-1])
lefts.append(count + 1)
rights.append(-1)
missings.append(-1)
pos = len(rights) - 1
count = _tree_traversal(node["left_child"], lefts, rights, features, thresholds, values, missings, count + 1)
count = _tree_traversal(node["left_child"], lefts, rights, features, thresholds, values, count + 1)
rights[pos] = count + 1
if node['missing_type'] == 'None':
# Missing values not present in training data are treated as zeros during inference.
missings[pos] = lefts[pos] if 0 < node["threshold"] else rights[pos]
else:
missings[pos] = lefts[pos] if node["default_left"] else rights[pos]
return _tree_traversal(node["right_child"], lefts, rights, features, thresholds, values, missings, count + 1)
return _tree_traversal(node["right_child"], lefts, rights, features, thresholds, values, count + 1)
else:
features.append(0)
thresholds.append(0)
values.append([node["leaf_value"]])
lefts.append(-1)
rights.append(-1)
missings.append(-1)
return count
@ -55,10 +48,9 @@ def _get_tree_parameters(tree_info):
features = []
thresholds = []
values = []
missings = []
_tree_traversal(tree_info["tree_structure"], lefts, rights, features, thresholds, values, missings, 0)
_tree_traversal(tree_info["tree_structure"], lefts, rights, features, thresholds, values, 0)
return TreeParameters(lefts, rights, features, thresholds, values, missings)
return TreeParameters(lefts, rights, features, thresholds, values)
def convert_sklearn_lgbm_classifier(operator, device, extra_config):
@ -74,14 +66,14 @@ def convert_sklearn_lgbm_classifier(operator, device, extra_config):
A PyTorch model
"""
assert operator is not None, "Cannot convert None operator"
assert not hasattr(operator.raw_operator, "use_missing") or operator.raw_operator.use_missing
assert not hasattr(operator.raw_operator, "zero_as_missing") or not operator.raw_operator.zero_as_missing
n_features = operator.raw_operator._n_features
tree_infos = operator.raw_operator.booster_.dump_model()["tree_info"]
n_classes = operator.raw_operator._n_classes
return convert_gbdt_classifier_common(operator, tree_infos, _get_tree_parameters, n_features, n_classes, missing_val=None, extra_config=extra_config)
return convert_gbdt_classifier_common(
operator, tree_infos, _get_tree_parameters, n_features, n_classes, extra_config=extra_config
)
def convert_sklearn_lgbm_regressor(operator, device, extra_config):
@ -97,8 +89,6 @@ def convert_sklearn_lgbm_regressor(operator, device, extra_config):
A PyTorch model
"""
assert operator is not None, "Cannot convert None operator"
assert not hasattr(operator.raw_operator, "use_missing") or operator.raw_operator.use_missing
assert not hasattr(operator.raw_operator, "zero_as_missing") or not operator.raw_operator.zero_as_missing
# Get tree information out of the model.
n_features = operator.raw_operator._n_features
@ -106,7 +96,7 @@ def convert_sklearn_lgbm_regressor(operator, device, extra_config):
if operator.raw_operator._objective == "tweedie":
extra_config[constants.POST_TRANSFORM] = constants.TWEEDIE
return convert_gbdt_common(operator, tree_infos, _get_tree_parameters, n_features, missing_val=None, extra_config=extra_config)
return convert_gbdt_common(operator, tree_infos, _get_tree_parameters, n_features, extra_config=extra_config)
# Register the converters.
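_get_tree_parameters above flattens LightGBM's nested dump_model() output into the parallel arrays consumed by TreeParameters, with -1 marking the absence of a child. A hedged sketch of the result for a hand-made stump (the dict shape follows LightGBM's dump format, but the exact field coverage is trimmed for illustration):

# Toy fragment shaped like booster_.dump_model()["tree_info"][i]["tree_structure"].
toy_tree = {
    "split_feature": 0, "threshold": 0.5, "missing_type": "None", "default_left": True,
    "left_child": {"leaf_value": 1.0},
    "right_child": {"leaf_value": 2.0},
}
# _tree_traversal walks this depth-first; for the stump it yields roughly:
# lefts      = [1, -1, -1]           # node 0 -> node 1; leaves have no children
# rights     = [2, -1, -1]
# features   = [0, 0, 0]             # leaves get a placeholder feature 0
# thresholds = [0.5, 0, 0]
# values     = [[-1], [1.0], [2.0]]  # internal nodes carry a dummy value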

View file

@ -214,14 +214,7 @@ def convert_onnx_tree_ensemble_classifier(operator, device=None, extra_config={}
)
extra_config[constants.POST_TRANSFORM] = post_transform
return convert_gbdt_classifier_common(
operator,
tree_infos,
_dummy_get_parameter,
n_features,
len(classes),
classes,
missing_val=None,
extra_config=extra_config,
operator, tree_infos, _dummy_get_parameter, n_features, len(classes), classes, extra_config
)
@ -243,9 +236,7 @@ def convert_onnx_tree_ensemble_regressor(operator, device=None, extra_config={})
n_features, tree_infos, _, _ = _get_tree_infos_from_tree_ensemble(operator.raw_operator, device, extra_config)
# Generate the model.
return convert_gbdt_common(
operator, tree_infos, _dummy_get_parameter, n_features, missing_val=None, extra_config=extra_config
)
return convert_gbdt_common(operator, tree_infos, _dummy_get_parameter, n_features, extra_config=extra_config)
register_converter("ONNXMLTreeEnsembleClassifier", convert_onnx_tree_ensemble_classifier)

View file

@ -42,8 +42,8 @@ def convert_sklearn_random_forest_classifier(operator, device, extra_config):
if not all(isinstance(c, int) for c in classes):
raise RuntimeError("Random Forest Classifier translation only supports integer class labels")
def get_parameters_for_tree_trav(lefts, rights, features, thresholds, values, missings=None, extra_config={}):
return get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, missings, classes, extra_config)
def get_parameters_for_tree_trav(lefts, rights, features, thresholds, values, extra_config={}):
return get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, classes, extra_config)
return convert_decision_ensemble_tree_common(
operator,
@ -77,8 +77,8 @@ def convert_sklearn_random_forest_regressor(operator, device, extra_config):
# For Sklearn Trees we need to know how many trees are there for normalization.
extra_config[constants.NUM_TREES] = len(tree_infos)
def get_parameters_for_tree_trav(lefts, rights, features, thresholds, values, missings=None, extra_config={}):
return get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, missings, None, extra_config)
def get_parameters_for_tree_trav(lefts, rights, features, thresholds, values, extra_config={}):
return get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, None, extra_config)
return convert_decision_ensemble_tree_common(
operator,

View file

@ -90,7 +90,7 @@ def convert_sklearn_gbdt_classifier(operator, device, extra_config):
extra_config[constants.REORDER_TREES] = False
return convert_gbdt_classifier_common(
operator, tree_infos, get_parameters_for_sklearn_common, n_features, n_classes, classes, missing_val=None, extra_config=extra_config
operator, tree_infos, get_parameters_for_sklearn_common, n_features, n_classes, classes, extra_config
)
@ -126,7 +126,7 @@ def convert_sklearn_gbdt_regressor(operator, device, extra_config):
extra_config[constants.BASE_PREDICTION] = base_prediction
return convert_gbdt_common(operator, tree_infos, get_parameters_for_sklearn_common, n_features, missing_val=None, extra_config=extra_config)
return convert_gbdt_common(operator, tree_infos, get_parameters_for_sklearn_common, n_features, None, extra_config)
def convert_sklearn_hist_gbdt_classifier(operator, device, extra_config):
@ -167,7 +167,9 @@ def convert_sklearn_hist_gbdt_classifier(operator, device, extra_config):
extra_config[constants.BASE_PREDICTION] = base_prediction
extra_config[constants.REORDER_TREES] = False
return convert_gbdt_classifier_common(operator, tree_infos, _get_parameters_hist_gbdt, n_features, n_classes, classes, missing_val=None, extra_config=extra_config)
return convert_gbdt_classifier_common(
operator, tree_infos, _get_parameters_hist_gbdt, n_features, n_classes, classes, extra_config
)
def convert_sklearn_hist_gbdt_regressor(operator, device, extra_config):
@ -190,7 +192,7 @@ def convert_sklearn_hist_gbdt_regressor(operator, device, extra_config):
n_features = operator.raw_operator.n_features_
extra_config[constants.BASE_PREDICTION] = [[operator.raw_operator._baseline_prediction]]
return convert_gbdt_common(operator, tree_infos, _get_parameters_hist_gbdt, n_features, missing_val=None, extra_config=extra_config)
return convert_gbdt_common(operator, tree_infos, _get_parameters_hist_gbdt, n_features, None, extra_config)
# Register the converters.

View file

@ -239,7 +239,7 @@ def convert_sklearn_isolation_forest(operator, device, extra_config):
if tree_type == TreeImpl.gemm:
net_parameters = [
get_parameters_for_gemm_common(
tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, n_features, tree_param.missings
tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, n_features
)
for tree_param in tree_parameters
]
@ -247,7 +247,7 @@ def convert_sklearn_isolation_forest(operator, device, extra_config):
net_parameters = [
get_parameters_for_tree_trav_sklearn(
tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, tree_param.missings, classes
tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, classes
)
for tree_param in tree_parameters
]

View file

@ -16,7 +16,7 @@ from ._gbdt_commons import convert_gbdt_classifier_common, convert_gbdt_common
from ._tree_commons import TreeParameters
def _tree_traversal(tree_info, lefts, rights, missing, features, thresholds, values):
def _tree_traversal(tree_info, lefts, rights, features, thresholds, values):
"""
Recursive function for parsing a tree and filling the input data structures.
"""
@ -28,7 +28,6 @@ def _tree_traversal(tree_info, lefts, rights, missing, features, thresholds, val
values.append([float(tree_info[count].split("=")[1])])
lefts.append(-1)
rights.append(-1)
missing.append(-1)
count += 1
else:
features.append(int(tree_info[count].split(":")[1].split("<")[0].replace("[f", "")))
@ -57,8 +56,6 @@ def _tree_traversal(tree_info, lefts, rights, missing, features, thresholds, val
r_correct_id += 1
rights.append(r_correct_id)
missing_wrong_id = tree_info[count].split(",")[2].replace("missing=", "")
missing.append(l_correct_id if l_wrong_id == missing_wrong_id else r_correct_id)
count += 1
@ -68,15 +65,14 @@ def _get_tree_parameters(tree_info):
"""
lefts = []
rights = []
missing = []
features = []
thresholds = []
values = []
_tree_traversal(
tree_info.replace("[f", "").replace("[", "").replace("]", "").split(), lefts, rights, missing, features, thresholds, values
tree_info.replace("[f", "").replace("[", "").replace("]", "").split(), lefts, rights, features, thresholds, values
)
return TreeParameters(lefts, rights, features, thresholds, values, missing)
return TreeParameters(lefts, rights, features, thresholds, values)
def convert_sklearn_xgb_classifier(operator, device, extra_config):
@ -101,9 +97,10 @@ def convert_sklearn_xgb_classifier(operator, device, extra_config):
)
tree_infos = operator.raw_operator.get_booster().get_dump()
n_classes = operator.raw_operator.n_classes_
missing_val = operator.raw_operator.missing
return convert_gbdt_classifier_common(operator, tree_infos, _get_tree_parameters, n_features, n_classes, missing_val=missing_val, extra_config=extra_config)
return convert_gbdt_classifier_common(
operator, tree_infos, _get_tree_parameters, n_features, n_classes, extra_config=extra_config
)
def convert_sklearn_xgb_regressor(operator, device, extra_config):
@ -135,9 +132,8 @@ def convert_sklearn_xgb_regressor(operator, device, extra_config):
base_prediction = [base_prediction]
extra_config[constants.BASE_PREDICTION] = base_prediction
missing_val = operator.raw_operator.missing
return convert_gbdt_common(operator, tree_infos, _get_tree_parameters, n_features, missing_val=missing_val, extra_config=extra_config)
return convert_gbdt_common(operator, tree_infos, _get_tree_parameters, n_features, extra_config=extra_config)
# Register the converters.
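Unlike the LightGBM converter, the XGBoost converter parses the textual get_booster().get_dump() output, where an internal node line reads like 0:[f0<0.5] yes=1,no=2,missing=1 and a leaf like 1:leaf=0.3. A small sketch of the split-line parsing used in _tree_traversal above (toy input; the inline "]" handling is a slight simplification of the preprocessing done in _get_tree_parameters):

# Minimal parse of one internal-node line, mirroring the string splits above.
line = "0:[f0<0.5] yes=1,no=2,missing=1"
feature = int(line.split(":")[1].split("<")[0].replace("[f", ""))  # -> 0
threshold = float(line.split("<")[1].split("]")[0])                # -> 0.5
print(feature, threshold)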

View file

@ -9,7 +9,6 @@ import numpy as np
import hummingbird.ml
from hummingbird.ml._utils import lightgbm_installed, onnx_runtime_installed, tvm_installed
from tree_utils import gbdt_implementation_map
from sklearn.datasets import make_classification, make_regression
if lightgbm_installed():
import lightgbm as lgb
@ -259,43 +258,6 @@ class TestLGBMConverter(unittest.TestCase):
self.assertIsNotNone(torch_model)
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)
# Missing values test.
@unittest.skipIf(not lightgbm_installed(), reason="LightGBM test requires LightGBM installed")
def test_run_lgbm_classifier_w_missing_vals_converter(self):
warnings.filterwarnings("ignore")
for extra_config_param in ["gemm", "tree_trav", "perf_tree_trav"]:
for missing in [None, np.nan]:
for model_class, n_classes in zip([lgb.LGBMClassifier, lgb.LGBMClassifier, lgb.LGBMRegressor], [2, 3, None]):
model = model_class(use_missing=True, zero_as_missing=False)
# Missing values during training + inference.
if model_class == lgb.LGBMClassifier:
X, y = make_classification(n_samples=100, n_features=3, n_informative=3, n_redundant=0, n_repeated=0, n_classes=n_classes, random_state=2021)
else:
X, y = make_regression(n_samples=100, n_features=3, n_informative=3, random_state=2021)
X[:25][y[:25] == 0, 0] = np.nan if missing is None else missing
model.fit(X, y)
torch_model = hummingbird.ml.convert(model, "torch", X, extra_config={"tree_implementation": extra_config_param})
self.assertIsNotNone(torch_model)
if model_class == lgb.LGBMClassifier:
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
else:
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)
# Missing values during only inference.
model = model_class(use_missing=True, zero_as_missing=False)
if model_class == lgb.LGBMClassifier:
X, y = make_classification(n_samples=100, n_features=3, n_informative=3, n_redundant=0, n_repeated=0, n_classes=n_classes, random_state=2021)
else:
X, y = make_regression(n_samples=100, n_features=3, n_informative=3, random_state=2021)
model.fit(X, y)
torch_model = hummingbird.ml.convert(model, "torch", X, extra_config={"tree_implementation": extra_config_param})
X[:25][y[:25] == 0, 0] = np.nan if missing is None else missing
self.assertIsNotNone(torch_model)
if model_class == lgb.LGBMClassifier:
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
else:
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)
# Backend tests.
# Test TorchScript backend regression.
@unittest.skipIf(not lightgbm_installed(), reason="LightGBM test requires LightGBM installed")

View file

@ -10,7 +10,6 @@ import hummingbird.ml
from hummingbird.ml._utils import xgboost_installed, tvm_installed
from hummingbird.ml import constants
from tree_utils import gbdt_implementation_map
from sklearn.datasets import make_classification, make_regression
if xgboost_installed():
import xgboost as xgb
@ -225,43 +224,6 @@ class TestXGBoostConverter(unittest.TestCase):
self.assertIsNotNone(torch_model)
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
# Missing values test.
@unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires XGBoost installed")
def test_run_xgb_classifier_w_missing_vals_converter(self):
warnings.filterwarnings("ignore")
for extra_config_param in ["gemm", "tree_trav", "perf_tree_trav"]:
for missing in [None, -99999, np.nan]:
for model_class, n_classes in zip([xgb.XGBClassifier, xgb.XGBClassifier, xgb.XGBRegressor], [2, 3, None]):
model = model_class(missing=missing)
# Missing values during both training and inference.
if model_class == xgb.XGBClassifier:
X, y = make_classification(n_samples=100, n_features=3, n_informative=3, n_redundant=0, n_repeated=0, n_classes=n_classes, random_state=2021)
else:
X, y = make_regression(n_samples=100, n_features=3, n_informative=3, random_state=2021)
X[:25][y[:25] == 0, 0] = np.nan if missing is None else missing
model.fit(X, y)
torch_model = hummingbird.ml.convert(model, "torch", X, extra_config={"tree_implementation": extra_config_param})
self.assertIsNotNone(torch_model)
if model_class == xgb.XGBClassifier:
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
else:
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)
# Missing values during only inference.
model = model_class(missing=missing)
if model_class == xgb.XGBClassifier:
X, y = make_classification(n_samples=100, n_features=3, n_informative=3, n_redundant=0, n_repeated=0, n_classes=n_classes, random_state=2021)
else:
X, y = make_regression(n_samples=100, n_features=3, n_informative=3, random_state=2021)
model.fit(X, y)
X[:25][y[:25] == 0, 0] = np.nan if missing is None else missing
torch_model = hummingbird.ml.convert(model, "torch", X, extra_config={"tree_implementation": extra_config_param})
self.assertIsNotNone(torch_model)
if model_class == xgb.XGBClassifier:
np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
else:
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)
# Torchscript backends.
# Test TorchScript backend regression.
@unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires XGBoost installed")