Parent: b88b6f7e02
Commit: c525c51060
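Summary: taken together, the hunks below remove the missing-value (NaN) handling path from the tree converters: the `missing_val` / `missings` parameters and their docstring entries, the missing-branch tensors and `torch.where` logic in the PyTorch tree implementations (GEMM, tree traversal, perfect tree traversal), and the LightGBM/XGBoost missing-value tests.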
@@ -15,7 +15,9 @@ from ._tree_commons import get_tree_params_and_type, get_parameters_for_tree_tra
 from ._tree_implementations import GEMMGBDTImpl, TreeTraversalGBDTImpl, PerfectTreeTraversalGBDTImpl, TreeImpl


-def convert_gbdt_classifier_common(operator, tree_infos, get_tree_parameters, n_features, n_classes, classes=None, missing_val=None, extra_config={}):
+def convert_gbdt_classifier_common(
+    operator, tree_infos, get_tree_parameters, n_features, n_classes, classes=None, extra_config={}
+):
     """
     Common converter for GBDT classifiers.

@@ -25,7 +27,6 @@ def convert_gbdt_classifier_common(operator, tree_infos, get_tree_parameters, n_
         n_features: The number of features input to the model
         n_classes: How many classes are expected. 1 for regression tasks
         classes: The classes used for classification. None if implementing a regression model
-        missing_val: The value to be treated as the missing value
         extra_config: Extra configuration used to properly implement the source tree

     Returns:
@@ -47,10 +48,10 @@ def convert_gbdt_classifier_common(operator, tree_infos, get_tree_parameters, n_
     if reorder_trees and n_classes > 1:
         tree_infos = [tree_infos[i * n_classes + j] for j in range(n_classes) for i in range(len(tree_infos) // n_classes)]

-    return convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes, missing_val, extra_config)
+    return convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes, extra_config)


-def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes=None, missing_val=None, extra_config={}):
+def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes=None, extra_config={}):
     """
     Common converter for GBDT models.

@@ -59,7 +60,6 @@ def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, c
         get_tree_parameters: A function specifying how to parse the tree_infos into parameters
         n_features: The number of features input to the model
         classes: The classes used for classification. None if implementing a regression model
-        missing_val: The value to be treated as the missing value
         extra_config: Extra configuration used to properly implement the source tree

     Returns:
@@ -86,7 +86,6 @@ def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, c
                 tree_param.thresholds,
                 tree_param.values,
                 n_features,
-                tree_param.missings,
                 extra_config,
             )
             for tree_param in tree_parameters
@@ -104,7 +103,6 @@ def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, c
                 tree_param.features,
                 tree_param.thresholds,
                 tree_param.values,
-                tree_param.missings,
                 extra_config,
             )
             for tree_param in tree_parameters
@@ -163,8 +161,8 @@ def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, c

     # Generate the tree implementation based on the selected strategy.
     if tree_type == TreeImpl.gemm:
-        return GEMMGBDTImpl(operator, net_parameters, n_features, classes, missing_val, extra_config)
+        return GEMMGBDTImpl(operator, net_parameters, n_features, classes, extra_config)
     if tree_type == TreeImpl.tree_trav:
-        return TreeTraversalGBDTImpl(operator, net_parameters, max_depth, n_features, classes, missing_val, extra_config)
+        return TreeTraversalGBDTImpl(operator, net_parameters, max_depth, n_features, classes, extra_config)
     else:  # Remaining possible case: tree_type == TreeImpl.perf_tree_trav.
-        return PerfectTreeTraversalGBDTImpl(operator, net_parameters, max_depth, n_features, classes, missing_val, extra_config)
+        return PerfectTreeTraversalGBDTImpl(operator, net_parameters, max_depth, n_features, classes, extra_config)
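For readers of the hunks above: when `reorder_trees` is set, boosting libraries emit trees interleaved per boosting round (round 0's trees for classes 0..k-1, then round 1's, and so on), and the comprehension regroups them so all trees of one class are contiguous. A minimal sketch of the effect, with illustrative string labels standing in for the real `tree_infos` objects:

```python
# Illustrative only: regroup round-interleaved trees by class, as the comprehension above does.
tree_infos = ["r0_c0", "r0_c1", "r0_c2", "r1_c0", "r1_c1", "r1_c2"]  # 2 rounds x 3 classes
n_classes = 3

reordered = [tree_infos[i * n_classes + j] for j in range(n_classes) for i in range(len(tree_infos) // n_classes)]
print(reordered)  # ['r0_c0', 'r1_c0', 'r0_c1', 'r1_c1', 'r0_c2', 'r1_c2'] -- grouped by class
```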
@@ -31,7 +31,6 @@ class Node:
             feature: The feature used to make a decision (if not leaf node, ignored otherwise)
             threshold: The threshold used in the decision (if not leaf node, ignored otherwise)
             value: The value stored in the leaf (ignored if not leaf node).
-            missing: In the case of a missing value, the chosen child node id
         """
         self.id = id
         self.left = None
@@ -39,7 +38,6 @@ class Node:
         self.feature = None
         self.threshold = None
         self.value = None
-        self.missing = None


 class TreeParameters:
@@ -47,7 +45,7 @@ class TreeParameters:
     Class containing a convenient in-memory representation of a decision tree.
     """

-    def __init__(self, lefts, rights, features, thresholds, values, missings=None):
+    def __init__(self, lefts, rights, features, thresholds, values):
         """
         Args:
             lefts: The id of the left nodes
@@ -55,14 +53,12 @@ class TreeParameters:
             features: The features used to make decisions
             thresholds: The thresholds used in the decisions
             values: The value stored in the leaves
-            missings: In the case of a missing value, which child node to select
         """
         self.lefts = lefts
         self.rights = rights
         self.features = features
         self.thresholds = thresholds
         self.values = values
-        self.missings = missings


 def _find_max_depth(tree_parameters):
@@ -188,7 +184,7 @@ def get_parameters_for_sklearn_common(tree_infos):
     return TreeParameters(lefts, rights, features, thresholds, values)


-def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, values, missings=None, extra_config={}):
+def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, values, extra_config={}):
     """
     Common function used by all tree algorithms to generate the parameters according to the tree_trav strategies.

@@ -198,7 +194,7 @@ def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, val
         features: The features used in the decision nodes
         thresholds: The thresholds used in the decision nodes
         values: The values stored in the leaf nodes
-        missings: In the case of a missing value, which child node to select

     Returns:
         An array containing the extracted parameters
     """
@@ -209,26 +205,18 @@ def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, val
         rights = [2, -1, -1]
         features = [0, 0, 0]
         thresholds = [0, 0, 0]
-        if missings is not None:
-            missings = [2, -1, -1]
         n_classes = values.shape[1] if type(values) is np.ndarray else 1
         values = np.array([np.zeros(n_classes), values[0], values[0]])
         values.reshape(3, n_classes)

     ids = [i for i in range(len(lefts))]
-    if missings is not None:
-        nodes = list(zip(ids, lefts, rights, features, thresholds, values, missings))
-    else:
-        nodes = list(zip(ids, lefts, rights, features, thresholds, values))
+    nodes = list(zip(ids, lefts, rights, features, thresholds, values))

     # Refactor the tree parameters in the proper format.
     nodes_map = {0: Node(0)}
     current_node = 0
     for i, node in enumerate(nodes):
-        if missings is not None:
-            id, left, right, feature, threshold, value, missing = node
-        else:
-            id, left, right, feature, threshold, value = node
+        id, left, right, feature, threshold, value = node

         if left != -1:
             l_node = Node(left)
@@ -252,13 +240,6 @@ def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, val
             nodes_map[current_node].threshold = threshold
             nodes_map[current_node].value = value

-            if missings is not None:
-                m_node = l_node if missing == left else r_node
-                nodes_map[current_node].missing = m_node
-
-            if missings[i] == -1:
-                missings[i] = id
-
         current_node += 1

     lefts = np.array(lefts)
@@ -266,13 +247,11 @@ def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, val
     features = np.array(features)
     thresholds = np.array(thresholds)
     values = np.array(values)
-    if missings is not None:
-        missings = np.array(missings)

-    return [nodes_map, ids, lefts, rights, features, thresholds, values, missings]
+    return [nodes_map, ids, lefts, rights, features, thresholds, values]


-def get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, missings=None, classes=None, extra_config={}):
+def get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, classes=None, extra_config={}):
     """
     This function is used to generate tree parameters for sklearn trees.
     Includes SklearnRandomForestClassifier/Regressor, and SklearnGradientBoostingClassifier.
@@ -283,7 +262,6 @@ def get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, va
         features: The features used in the decision nodes
         thresholds: The thresholds used in the decision nodes
         values: The values stored in the leaf nodes
-        missings: In the case of a missing value, which child node to select
         classes: The list of class labels. None if regression model
     Returns:
         An array containing the extracted parameters
@@ -298,10 +276,10 @@ def get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, va
     if constants.NUM_TREES in extra_config:
         values /= extra_config[constants.NUM_TREES]

-    return get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, values, missings)
+    return get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, values)


-def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values, n_features, missings=None, extra_config={}):
+def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values, n_features, extra_config={}):
     """
     Common function used by all tree algorithms to generate the parameters according to the GEMM strategy.

@@ -312,7 +290,7 @@ def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values,
         thresholds: The thresholds used in the decision nodes
         values: The values stored in the leaf nodes
         n_features: The number of expected input features
-        missings: In the case of a missing value, which child node to select

     Returns:
         The weights and bias for the GEMM implementation
     """
@@ -327,35 +305,20 @@ def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values,
         rights = [2, -1, -1]
         features = [0, 0, 0]
         thresholds = [0, 0, 0]
-        if missings is not None:
-            missings = [2, -1, -1]
         n_classes = values.shape[1]
         values = np.array([np.zeros(n_classes), values[0], values[0]])
         values.reshape(3, n_classes)

-    if missings is None:
-        missings = rights
-
     # First hidden layer has all inequalities.
     hidden_weights = []
     hidden_biases = []
-    hidden_missing_biases = []
-    for left, right, missing, feature, thresh in zip(lefts, rights, missings, features, thresholds):
-        if left != -1 or right != -1:
+    for left, feature, thresh in zip(lefts, features, thresholds):
+        if left != -1:
             hidden_weights.append([1 if i == feature else 0 for i in range(n_features)])
             hidden_biases.append(thresh)
-
-            if missing == right:
-                hidden_missing_biases.append(1)
-            else:
-                hidden_missing_biases.append(0)
     weights.append(np.array(hidden_weights).astype("float32"))
     biases.append(np.array(hidden_biases).astype("float32"))

-    # Missing value handling biases.
-    weights.append(None)
-    biases.append(np.array(hidden_missing_biases).astype("float32"))
-
    n_splits = len(hidden_weights)

     # Second hidden layer has ANDs for each leaf of the decision tree.
@@ -381,10 +344,10 @@ def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values,
             for j, p in enumerate(path[:-1]):
                 num_leaves_before_p = list(lefts[:p]).count(-1)
                 if path[j + 1] in lefts:
-                    vec[p - num_leaves_before_p] = -1
-                elif path[j + 1] in rights:
-                    num_positive += 1
                     vec[p - num_leaves_before_p] = 1
+                    num_positive += 1
+                elif path[j + 1] in rights:
+                    vec[p - num_leaves_before_p] = -1
                 else:
                     raise RuntimeError("Inconsistent state encountered while tree translation.")
@@ -433,7 +396,6 @@ def convert_decision_ensemble_tree_common(
                 tree_param.thresholds,
                 tree_param.values,
                 n_features,
-                tree_param.missings,
                 extra_config,
             )
             for tree_param in tree_parameters
@@ -442,7 +404,7 @@ def convert_decision_ensemble_tree_common(

     net_parameters = [
         get_parameters_for_tree_trav(
-            tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, tree_param.missings, extra_config,
+            tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, extra_config,
         )
         for tree_param in tree_parameters
     ]
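For context on what `get_parameters_for_gemm_common` is building: the GEMM strategy encodes each tree as three dense layers — layer 1 evaluates every split condition at once, layer 2 ANDs the conditions along each root-to-leaf path (that is what the `vec` / `num_positive` bookkeeping in the hunk above constructs), and layer 3 maps the resulting one-hot leaf indicator to leaf values. A minimal NumPy sketch on a hand-written one-split tree; the matrices here are written by hand for illustration, not produced by the function above:

```python
import numpy as np

# Tiny tree: if x[0] < 0.5 -> leaf A (value 1.0), else leaf B (value 2.0).
A = np.array([[1.0, 0.0]])     # layer 1: one row per split, selects feature 0
B = np.array([0.5])            # layer 1: split thresholds
C = np.array([[1.0], [-1.0]])  # layer 2: leaf A needs the condition true, leaf B false
D = np.array([1.0, 0.0])       # layer 2: count of "true" conditions on each leaf's path
E = np.array([[1.0], [2.0]])   # layer 3: leaf values

x = np.array([0.3, 7.0])
cond = (A @ x < B).astype(np.float64)       # which split conditions hold (matches the x < threshold convention above)
leaf = (C @ cond == D).astype(np.float64)   # one-hot indicator over leaves
print(E.T @ leaf)                           # [1.] -> the left leaf is selected
```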
@@ -53,14 +53,13 @@ class AbstractPyTorchTreeImpl(AbstracTreeImpl, torch.nn.Module):
     Abstract class defining the basic structure for tree-based models implemented in PyTorch.
     """

-    def __init__(self, logical_operator, tree_parameters, n_features, classes, n_classes, missing_val, **kwargs):
+    def __init__(self, logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
             n_features: The number of features input to the model
             classes: The classes used for classification. None if implementing a regression model
             n_classes: The total number of used classes
-            missing_val: The value to be treated as the missing value
         """
         super(AbstractPyTorchTreeImpl, self).__init__(logical_operator, **kwargs)

@@ -84,32 +83,23 @@ class AbstractPyTorchTreeImpl(AbstracTreeImpl, torch.nn.Module):
             self.classes = torch.nn.Parameter(torch.IntTensor(classes), requires_grad=False)
             self.perform_class_select = True

-        self.missing_val = missing_val
-        if self.missing_val in [None, np.nan]:
-            self.missing_val_op = torch.isnan
-        else:
-            def missing_val_op(x):
-                return x == self.missing_val
-            self.missing_val_op = missing_val_op
-

 class GEMMTreeImpl(AbstractPyTorchTreeImpl):
     """
     Class implementing the GEMM strategy in PyTorch for tree-based models.
     """

-    def __init__(self, logical_operator, tree_parameters, n_features, classes, n_classes=None, missing_val=None, extra_config={}, **kwargs):
+    def __init__(self, logical_operator, tree_parameters, n_features, classes, n_classes=None, extra_config={}, **kwargs):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
             n_features: The number of features input to the model
             classes: The classes used for classification. None if implementing a regression model
             n_classes: The total number of used classes
-            missing_val: The value to be treated as the missing value
         """
         # If n_classes is not provided we induce it from tree parameters. Multioutput regression targets are also treated as separate classes.
-        n_classes = n_classes if n_classes is not None else tree_parameters[0][0][3].shape[0]
-        super(GEMMTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, n_classes, missing_val, **kwargs)
+        n_classes = n_classes if n_classes is not None else tree_parameters[0][0][2].shape[0]
+        super(GEMMTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs)

         # Initialize the actual model.
         hidden_one_size = 0
@@ -118,24 +108,22 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl):

         for weight, bias in tree_parameters:
             hidden_one_size = max(hidden_one_size, weight[0].shape[0])
-            hidden_two_size = max(hidden_two_size, weight[2].shape[0])
+            hidden_two_size = max(hidden_two_size, weight[1].shape[0])

         n_trees = len(tree_parameters)
-        weight_1 = np.zeros((n_trees, hidden_one_size))
+        weight_1 = np.zeros((n_trees, hidden_one_size, n_features))
         bias_1 = np.zeros((n_trees, hidden_one_size))
-        missing_bias_1 = np.zeros((n_trees, hidden_one_size))
         weight_2 = np.zeros((n_trees, hidden_two_size, hidden_one_size))
         bias_2 = np.zeros((n_trees, hidden_two_size))
         weight_3 = np.zeros((n_trees, hidden_three_size, hidden_two_size))

         for i, (weight, bias) in enumerate(tree_parameters):
             if len(weight[0]) > 0:
-                weight_1[i, 0 : weight[0].shape[0]] = np.argmax(weight[0], axis=1)
+                weight_1[i, 0 : weight[0].shape[0], 0 : weight[0].shape[1]] = weight[0]
             bias_1[i, 0 : bias[0].shape[0]] = bias[0]
-            missing_bias_1[i, 0 : bias[1].shape[0]] = bias[1]
-            weight_2[i, 0 : weight[2].shape[0], 0 : weight[2].shape[1]] = weight[2]
-            bias_2[i, 0 : bias[2].shape[0]] = bias[2]
-            weight_3[i, 0 : weight[3].shape[0], 0 : weight[3].shape[1]] = weight[3]
+            weight_2[i, 0 : weight[1].shape[0], 0 : weight[1].shape[1]] = weight[1]
+            bias_2[i, 0 : bias[1].shape[0]] = bias[1]
+            weight_3[i, 0 : weight[2].shape[0], 0 : weight[2].shape[1]] = weight[2]

         self.n_trees = n_trees
         self.n_features = n_features
@@ -143,19 +131,13 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl):
         self.hidden_two_size = hidden_two_size
         self.hidden_three_size = hidden_three_size

-        self.weight_1 = torch.nn.Parameter(torch.from_numpy(weight_1.reshape(-1).astype("int64")), requires_grad=False)
-        self.bias_1 = torch.nn.Parameter(torch.from_numpy(bias_1.reshape(1, -1).astype("float32")), requires_grad=False)
+        self.weight_1 = torch.nn.Parameter(torch.from_numpy(weight_1.reshape(-1, self.n_features).astype("float32")))
+        self.bias_1 = torch.nn.Parameter(torch.from_numpy(bias_1.reshape(-1, 1).astype("float32")))

-        # By default when we compare nan to any value the output will be false. Thus we need to explicitly
-        # account for missing values only when missings are different to lefts (i.e., False condition)
-        if np.sum(missing_bias_1) != 0:
-            self.missing_bias_1 = torch.nn.Parameter(torch.from_numpy(missing_bias_1.reshape(1, -1).astype("float32")), requires_grad=False)
-        else:
-            self.missing_bias_1 = None
-
-        self.weight_2 = torch.nn.Parameter(torch.from_numpy(weight_2.astype("float32")), requires_grad=False)
-        self.bias_2 = torch.nn.Parameter(torch.from_numpy(bias_2.reshape(-1, 1).astype("float32")), requires_grad=False)
-        self.weight_3 = torch.nn.Parameter(torch.from_numpy(weight_3.astype("float32")), requires_grad=False)
+        self.weight_2 = torch.nn.Parameter(torch.from_numpy(weight_2.astype("float32")))
+        self.bias_2 = torch.nn.Parameter(torch.from_numpy(bias_2.reshape(-1, 1).astype("float32")))
+        self.weight_3 = torch.nn.Parameter(torch.from_numpy(weight_3.astype("float32")))

         # We register also base_prediction here so that tensor will be moved to the proper hardware with the model.
         # i.e., if cuda is selected, the parameter will be automatically moved on the GPU.
@@ -166,12 +148,10 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl):
         return x

     def forward(self, x):
-        features = torch.index_select(x, 1, self.weight_1)
-        if self.missing_bias_1 is not None:
-            x = torch.where(self.missing_val_op(features), self.missing_bias_1 + torch.zeros_like(features), (features >= self.bias_1).float())
-        else:
-            x = (features >= self.bias_1).float()
-        x = x.view(-1, self.n_trees * self.hidden_one_size).t().view(self.n_trees, self.hidden_one_size, -1)
+        x = x.t()
+        x = torch.mm(self.weight_1, x) < self.bias_1
+        x = x.view(self.n_trees, self.hidden_one_size, -1)
+        x = x.float()

         x = torch.matmul(self.weight_2, x)
@@ -207,7 +187,9 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
         indexes = indexes.expand(batch_size, self.num_trees)
         return indexes.reshape(-1)

-    def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes, n_classes=None, missing_val=None, extra_config={}, **kwargs):
+    def __init__(
+        self, logical_operator, tree_parameters, max_depth, n_features, classes, n_classes=None, extra_config={}, **kwargs
+    ):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
@@ -215,12 +197,13 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
             n_features: The number of features input to the model
             classes: The classes used for classification. None if implementing a regression model
             n_classes: The total number of used classes
-            missing_val: The value to be treated as the missing value
             extra_config: Extra configuration used to properly implement the source tree
         """
         # If n_classes is not provided we induce it from tree parameters. Multioutput regression targets are also treated as separate classes.
         n_classes = n_classes if n_classes is not None else tree_parameters[0][6].shape[1]
-        super(TreeTraversalTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, n_classes, missing_val, **kwargs)
+        super(TreeTraversalTreeImpl, self).__init__(
+            logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs
+        )

         # Initialize the actual model.
         self.n_features = n_features
@@ -235,28 +218,16 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
         thresholds = np.zeros((self.num_trees, self.num_nodes), dtype=np.float32)
         values = np.zeros((self.num_trees, self.num_nodes, self.n_classes), dtype=np.float32)

-        missings = None
-        if len(tree_parameters[0]) == 8 and tree_parameters[0][7] is not None:
-            missings = np.zeros((self.num_trees, self.num_nodes), dtype=np.int64)
-
         for i in range(self.num_trees):
             lefts[i][: len(tree_parameters[i][0])] = tree_parameters[i][2]
             rights[i][: len(tree_parameters[i][0])] = tree_parameters[i][3]
             features[i][: len(tree_parameters[i][0])] = tree_parameters[i][4]
             thresholds[i][: len(tree_parameters[i][0])] = tree_parameters[i][5]
             values[i][: len(tree_parameters[i][0])][:] = tree_parameters[i][6]
-            if missings is not None:
-                missings[i][: len(tree_parameters[i][0])] = tree_parameters[i][7]

         self.lefts = torch.nn.Parameter(torch.from_numpy(lefts).view(-1), requires_grad=False)
         self.rights = torch.nn.Parameter(torch.from_numpy(rights).view(-1), requires_grad=False)

-        # By default when we compare nan to any value the output will be false. Thus we need to explicitly
-        # account for missing values when missings are different to lefts (i.e., false condition)
-        self.missings = None
-        if missings is not None and not np.allclose(lefts, missings):
-            self.missings = torch.nn.Parameter(torch.from_numpy(missings).view(-1), requires_grad=False)
-
         self.features = torch.nn.Parameter(torch.from_numpy(features).view(-1), requires_grad=False)
         self.thresholds = torch.nn.Parameter(torch.from_numpy(thresholds).view(-1))
         self.values = torch.nn.Parameter(torch.from_numpy(values).view(-1, self.n_classes))
@@ -284,12 +255,7 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
         lefts = torch.index_select(self.lefts, 0, indexes).view(-1, self.num_trees)
         rights = torch.index_select(self.rights, 0, indexes).view(-1, self.num_trees)

-        if self.missings is not None:
-            missings = torch.index_select(self.missings, 0, indexes).view(-1, self.num_trees)
-
         indexes = torch.where(torch.ge(feature_values, thresholds), rights, lefts).long()
-        if self.missings is not None:
-            indexes = torch.where(self.missing_val_op(feature_values), missings, indexes)
         indexes = indexes + self.nodes_offset
         indexes = indexes.view(-1)
@@ -315,19 +281,22 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
     Class implementing the Perfect Tree Traversal strategy in PyTorch for tree-based models.
     """

-    def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes, n_classes=None, missing_val=None, extra_config={}, **kwargs):
+    def __init__(
+        self, logical_operator, tree_parameters, max_depth, n_features, classes, n_classes=None, extra_config={}, **kwargs
+    ):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
             max_depth: The maximum tree-depth in the model
             n_features: The number of features input to the model
             classes: The classes used for classification. None if implementing a regression model
-            missing_val: The value to be treated as the missing value
             n_classes: The total number of used classes
         """
         # If n_classes is not provided we induce it from tree parameters. Multioutput regression targets are also treated as separate classes.
         n_classes = n_classes if n_classes is not None else tree_parameters[0][6].shape[1]
-        super(PerfectTreeTraversalTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, n_classes, missing_val, **kwargs)
+        super(PerfectTreeTraversalTreeImpl, self).__init__(
+            logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs
+        )

         # Initialize the actual model.
         self.max_tree_depth = max_depth
@@ -336,53 +305,40 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):

         node_maps = [tp[0] for tp in tree_parameters]

-        feature_ids = np.zeros((self.num_trees, 2 ** max_depth - 1))
-        threshold_vals = np.zeros((self.num_trees, 2 ** max_depth - 1))
-        leaf_vals = np.zeros((self.num_trees, 2 ** max_depth, self.n_classes))
-        missings = np.zeros((self.num_trees, 2 ** max_depth - 1), dtype=np.int64)
+        weight_0 = np.zeros((self.num_trees, 2 ** max_depth - 1))
+        bias_0 = np.zeros((self.num_trees, 2 ** max_depth - 1))
+        weight_1 = np.zeros((self.num_trees, 2 ** max_depth, self.n_classes))

-        # By default when we compare nan to any value the output will be false. Thus, in `_populate_structure_tensors` we check whether there are
-        # non-trivial missings that are different to lefts (i.e., false condition) and set the `self.has_non_trivial_missing_vals` to True.
-        self.has_non_trivial_missing_vals = False
         for i, node_map in enumerate(node_maps):
-            self._populate_structure_tensors(node_map, max_depth, feature_ids[i], leaf_vals[i], threshold_vals[i], missings[i])
+            self._get_weights_and_biases(node_map, max_depth, weight_0[i], weight_1[i], bias_0[i])

         node_by_levels = [set() for _ in range(max_depth)]
         self._traverse_by_level(node_by_levels, 0, -1, max_depth)

-        self.root_nodes = torch.nn.Parameter(torch.from_numpy(feature_ids[:, 0].flatten().astype("int64")), requires_grad=False)
-        self.root_biases = torch.nn.Parameter(-1 * torch.from_numpy(threshold_vals[:, 0].astype("float32")), requires_grad=False)
-        self.root_missing_node_ids = torch.nn.Parameter(torch.from_numpy(missings[:, 0].astype("int64")), requires_grad=False)
+        self.root_nodes = torch.nn.Parameter(torch.from_numpy(weight_0[:, 0].flatten().astype("int64")), requires_grad=False)
+        self.root_biases = torch.nn.Parameter(-1 * torch.from_numpy(bias_0[:, 0].astype("float32")), requires_grad=False)

         tree_indices = np.array([i for i in range(0, 2 * self.num_trees, 2)]).astype("int64")
         self.tree_indices = torch.nn.Parameter(torch.from_numpy(tree_indices), requires_grad=False)

-        self.feature_ids = []
-        self.threshold_vals = []
-        self.missing_node_ids = []
+        self.nodes = []
+        self.biases = []
         for i in range(1, max_depth):
-            features = torch.nn.Parameter(
-                torch.from_numpy(feature_ids[:, list(sorted(node_by_levels[i]))].flatten().astype("int64")), requires_grad=False
+            nodes = torch.nn.Parameter(
+                torch.from_numpy(weight_0[:, list(sorted(node_by_levels[i]))].flatten().astype("int64")), requires_grad=False
             )
-            thresholds = torch.nn.Parameter(
-                torch.from_numpy(-1 * threshold_vals[:, list(sorted(node_by_levels[i]))].flatten().astype("float32")),
+            biases = torch.nn.Parameter(
+                torch.from_numpy(-1 * bias_0[:, list(sorted(node_by_levels[i]))].flatten().astype("float32")),
                 requires_grad=False,
             )
-            missing_nodes = torch.nn.Parameter(
-                torch.from_numpy(missings[:, list(sorted(node_by_levels[i]))].flatten().astype("int64")),
-                requires_grad=False,
-            )
-
-            self.feature_ids.append(features)
-            self.threshold_vals.append(thresholds)
-            self.missing_node_ids.append(missing_nodes)
+            self.nodes.append(nodes)
+            self.biases.append(biases)

-        self.feature_ids = torch.nn.ParameterList(self.feature_ids)
-        self.threshold_vals = torch.nn.ParameterList(self.threshold_vals)
-        self.missing_node_ids = torch.nn.ParameterList(self.missing_node_ids)
+        self.nodes = torch.nn.ParameterList(self.nodes)
+        self.biases = torch.nn.ParameterList(self.biases)

         self.leaf_nodes = torch.nn.Parameter(
-            torch.from_numpy(leaf_vals.reshape((-1, self.n_classes)).astype("float32")), requires_grad=False
+            torch.from_numpy(weight_1.reshape((-1, self.n_classes)).astype("float32")), requires_grad=False
         )

         # We register also base_prediction here so that tensor will be moved to the proper hardware with the model.
@@ -394,23 +350,15 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
         return x

     def forward(self, x):
-        root_features = torch.index_select(x, 1, self.root_nodes)
-        prev_indices = torch.ge(root_features, self.root_biases).long()
-        if self.has_non_trivial_missing_vals:
-            prev_indices = torch.where(self.missing_val_op(root_features), self.root_missing_node_ids, prev_indices)
+        prev_indices = (torch.ge(torch.index_select(x, 1, self.root_nodes), self.root_biases)).long()
         prev_indices = prev_indices + self.tree_indices
         prev_indices = prev_indices.view(-1)

         factor = 2
-        for features, thresholds, missings in zip(self.feature_ids, self.threshold_vals, self.missing_node_ids):
-            gather_indices = torch.index_select(features, 0, prev_indices).view(-1, self.num_trees)
+        for nodes, biases in zip(self.nodes, self.biases):
+            gather_indices = torch.index_select(nodes, 0, prev_indices).view(-1, self.num_trees)
             features = torch.gather(x, 1, gather_indices).view(-1)
-            thresholds = torch.index_select(thresholds, 0, prev_indices)
-            node_eval_status = torch.ge(features, thresholds).long()
-            if self.has_non_trivial_missing_vals:
-                missings = torch.index_select(missings, 0, prev_indices)
-                node_eval_status = torch.where(self.missing_val_op(features), missings, node_eval_status)
-            prev_indices = factor * prev_indices + node_eval_status
+            prev_indices = factor * prev_indices + torch.ge(features, torch.index_select(biases, 0, prev_indices)).long()

         output = torch.index_select(self.leaf_nodes, 0, prev_indices).view(-1, self.num_trees, self.n_classes)
@@ -438,24 +386,17 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
             node_id = self._traverse_by_level(node_by_levels, node_id, current_level, max_level)
         return node_id

-    def _populate_structure_tensors(self, nodes_map, tree_depth, node_ids, threshold_vals, leaf_vals, missings):
+    def _get_weights_and_biases(self, nodes_map, tree_depth, weight_0, weight_1, bias_0):
         def depth_f_traversal(node, current_depth, node_id, leaf_start_id):
-            node_ids[node_id] = node.feature
-            leaf_vals[node_id] = -node.threshold
-
-            if node.missing is None or node.left == node.missing:
-                missings[node_id] = 0
-            else:
-                missings[node_id] = 1
-                self.has_non_trivial_missing_vals = True
-
+            weight_0[node_id] = node.feature
+            bias_0[node_id] = -node.threshold
             current_depth += 1
             node_id += 1

             if node.left.feature == -1:
                 node_id += 2 ** (tree_depth - current_depth - 1) - 1
                 v = node.left.value
-                threshold_vals[leaf_start_id : leaf_start_id + 2 ** (tree_depth - current_depth - 1)] = (
+                weight_1[leaf_start_id : leaf_start_id + 2 ** (tree_depth - current_depth - 1)] = (
                     np.ones((2 ** (tree_depth - current_depth - 1), self.n_classes)) * v
                 )
                 leaf_start_id += 2 ** (tree_depth - current_depth - 1)
@@ -465,7 +406,7 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
             if node.right.feature == -1:
                 node_id += 2 ** (tree_depth - current_depth - 1) - 1
                 v = node.right.value
-                threshold_vals[leaf_start_id : leaf_start_id + 2 ** (tree_depth - current_depth - 1)] = (
+                weight_1[leaf_start_id : leaf_start_id + 2 ** (tree_depth - current_depth - 1)] = (
                     np.ones((2 ** (tree_depth - current_depth - 1), self.n_classes)) * v
                 )
                 leaf_start_id += 2 ** (tree_depth - current_depth - 1)
@@ -484,15 +425,14 @@ class GEMMDecisionTreeImpl(GEMMTreeImpl):

     """

-    def __init__(self, logical_operator, tree_parameters, n_features, classes=None, missing_val=None):
+    def __init__(self, logical_operator, tree_parameters, n_features, classes=None):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
             n_features: The number of features input to the model
             classes: The classes used for classification. None if implementing a regression model
-            missing_val: The value to be treated as the missing value
         """
-        super(GEMMDecisionTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, missing_val)
+        super(GEMMDecisionTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes)

     def aggregation(self, x):
         output = x.sum(0).t()
@@ -505,14 +445,13 @@ class TreeTraversalDecisionTreeImpl(TreeTraversalTreeImpl):
     Class implementing the Tree Traversal strategy in PyTorch for decision tree models.
     """

-    def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, missing_val=None, extra_config={}):
+    def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
             max_depth: The maximum tree-depth in the model
             n_features: The number of features input to the model
             classes: The classes used for classification. None if implementing a regression model
-            missing_val: The value to be treated as the missing value
             extra_config: Extra configuration used to properly implement the source tree
         """
         super(TreeTraversalDecisionTreeImpl, self).__init__(
@@ -530,16 +469,17 @@ class PerfectTreeTraversalDecisionTreeImpl(PerfectTreeTraversalTreeImpl):
     Class implementing the Perfect Tree Traversal strategy in PyTorch for decision tree models.
     """

-    def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, missing_val=None):
+    def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
             max_depth: The maximum tree-depth in the model
             n_features: The number of features input to the model
             classes: The classes used for classification. None if implementing a regression model
-            missing_val: The value to be treated as the missing value
         """
-        super(PerfectTreeTraversalDecisionTreeImpl, self).__init__(logical_operator, tree_parameters, max_depth, n_features, classes, missing_val)
+        super(PerfectTreeTraversalDecisionTreeImpl, self).__init__(
+            logical_operator, tree_parameters, max_depth, n_features, classes
+        )

     def aggregation(self, x):
         output = x.sum(1)
@@ -553,16 +493,15 @@ class GEMMGBDTImpl(GEMMTreeImpl):
     Class implementing the GEMM strategy (in PyTorch) for GBDT models.
     """

-    def __init__(self, logical_operator, tree_parameters, n_features, classes=None, missing_val=None, extra_config={}):
+    def __init__(self, logical_operator, tree_parameters, n_features, classes=None, extra_config={}):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
             n_features: The number of features input to the model
             classes: The classes used for classification. None if implementing a regression model
-            missing_val: The value to be treated as the missing value
             extra_config: Extra configuration used to properly implement the source tree
         """
-        super(GEMMGBDTImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, 1, missing_val, extra_config)
+        super(GEMMGBDTImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, 1, extra_config)

         self.n_gbdt_classes = 1
         self.post_transform = lambda x: x
@@ -586,17 +525,18 @@ class TreeTraversalGBDTImpl(TreeTraversalTreeImpl):
     Class implementing the Tree Traversal strategy in PyTorch.
     """

-    def __init__(self, logical_operator, tree_parameters, max_detph, n_features, classes=None, missing_val=None, extra_config={}):
+    def __init__(self, logical_operator, tree_parameters, max_detph, n_features, classes=None, extra_config={}):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
             max_depth: The maximum tree-depth in the model
             n_features: The number of features input to the model
             classes: The classes used for classification. None if implementing a regression model
-            missing_val: The value to be treated as the missing value
             extra_config: Extra configuration used to properly implement the source tree
         """
-        super(TreeTraversalGBDTImpl, self).__init__(logical_operator, tree_parameters, max_detph, n_features, classes, 1, missing_val, extra_config)
+        super(TreeTraversalGBDTImpl, self).__init__(
+            logical_operator, tree_parameters, max_detph, n_features, classes, 1, extra_config
+        )

         self.n_gbdt_classes = 1
         self.post_transform = lambda x: x
@@ -620,17 +560,18 @@ class PerfectTreeTraversalGBDTImpl(PerfectTreeTraversalTreeImpl):
     Class implementing the Perfect Tree Traversal strategy in PyTorch.
     """

-    def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, missing_val=None, extra_config={}):
+    def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
             max_depth: The maximum tree-depth in the model
             n_features: The number of features input to the model
             classes: The classes used for classification. None if implementing a regression model
-            missing_val: The value to be treated as the missing value
             extra_config: Extra configuration used to properly implement the source tree
         """
-        super(PerfectTreeTraversalGBDTImpl, self).__init__(logical_operator, tree_parameters, max_depth, n_features, classes, 1, missing_val, extra_config)
+        super(PerfectTreeTraversalGBDTImpl, self).__init__(
+            logical_operator, tree_parameters, max_depth, n_features, classes, 1, extra_config
+        )

         self.n_gbdt_classes = 1
         self.post_transform = lambda x: x
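All of the deleted branches above rely on the same IEEE-754 fact: any comparison against NaN is false, so with `torch.ge(feature_values, thresholds)` a NaN feature automatically takes the "false" branch. The removed `torch.where(self.missing_val_op(...))` logic was only needed when a model's designated missing-value direction differed from that default. A quick demonstration of the default behaviour:

```python
import torch

feature_values = torch.tensor([0.3, float("nan"), 0.9])
thresholds = torch.tensor([0.5, 0.5, 0.5])

# NaN >= 0.5 evaluates to False, so a missing value falls into the "false" branch by default.
print(torch.ge(feature_values, thresholds))  # tensor([False, False,  True])
print(torch.isnan(feature_values))           # tensor([False,  True, False]) -- the removed missing_val_op
```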
@@ -16,7 +16,7 @@ from ._gbdt_commons import convert_gbdt_classifier_common, convert_gbdt_common
 from ._tree_commons import TreeParameters


-def _tree_traversal(node, lefts, rights, features, thresholds, values, missings, count):
+def _tree_traversal(node, lefts, rights, features, thresholds, values, count):
     """
     Recursive function for parsing a tree and filling the input data structures.
     """
@@ -26,23 +26,16 @@ def _tree_traversal(node, lefts, rights, features, thresholds, values, missings,
         values.append([-1])
         lefts.append(count + 1)
         rights.append(-1)
-        missings.append(-1)
         pos = len(rights) - 1
-        count = _tree_traversal(node["left_child"], lefts, rights, features, thresholds, values, missings, count + 1)
+        count = _tree_traversal(node["left_child"], lefts, rights, features, thresholds, values, count + 1)
         rights[pos] = count + 1
-        if node['missing_type'] == 'None':
-            # Missing values not present in training data are treated as zeros during inference.
-            missings[pos] = lefts[pos] if 0 < node["threshold"] else rights[pos]
-        else:
-            missings[pos] = lefts[pos] if node["default_left"] else rights[pos]
-        return _tree_traversal(node["right_child"], lefts, rights, features, thresholds, values, missings, count + 1)
+        return _tree_traversal(node["right_child"], lefts, rights, features, thresholds, values, count + 1)
     else:
         features.append(0)
         thresholds.append(0)
         values.append([node["leaf_value"]])
         lefts.append(-1)
         rights.append(-1)
-        missings.append(-1)
         return count


@@ -55,10 +48,9 @@ def _get_tree_parameters(tree_info):
     features = []
     thresholds = []
     values = []
-    missings = []
-    _tree_traversal(tree_info["tree_structure"], lefts, rights, features, thresholds, values, missings, 0)
+    _tree_traversal(tree_info["tree_structure"], lefts, rights, features, thresholds, values, 0)

-    return TreeParameters(lefts, rights, features, thresholds, values, missings)
+    return TreeParameters(lefts, rights, features, thresholds, values)


 def convert_sklearn_lgbm_classifier(operator, device, extra_config):
@@ -74,14 +66,14 @@ def convert_sklearn_lgbm_classifier(operator, device, extra_config):
         A PyTorch model
     """
     assert operator is not None, "Cannot convert None operator"
-    assert not hasattr(operator.raw_operator, "use_missing") or operator.raw_operator.use_missing
-    assert not hasattr(operator.raw_operator, "zero_as_missing") or not operator.raw_operator.zero_as_missing

     n_features = operator.raw_operator._n_features
     tree_infos = operator.raw_operator.booster_.dump_model()["tree_info"]
     n_classes = operator.raw_operator._n_classes

-    return convert_gbdt_classifier_common(operator, tree_infos, _get_tree_parameters, n_features, n_classes, missing_val=None, extra_config=extra_config)
+    return convert_gbdt_classifier_common(
+        operator, tree_infos, _get_tree_parameters, n_features, n_classes, extra_config=extra_config
+    )


 def convert_sklearn_lgbm_regressor(operator, device, extra_config):
@@ -97,8 +89,6 @@ def convert_sklearn_lgbm_regressor(operator, device, extra_config):
         A PyTorch model
     """
     assert operator is not None, "Cannot convert None operator"
-    assert not hasattr(operator.raw_operator, "use_missing") or operator.raw_operator.use_missing
-    assert not hasattr(operator.raw_operator, "zero_as_missing") or not operator.raw_operator.zero_as_missing

     # Get tree information out of the model.
     n_features = operator.raw_operator._n_features
@@ -106,7 +96,7 @@ def convert_sklearn_lgbm_regressor(operator, device, extra_config):
     if operator.raw_operator._objective == "tweedie":
         extra_config[constants.POST_TRANSFORM] = constants.TWEEDIE

-    return convert_gbdt_common(operator, tree_infos, _get_tree_parameters, n_features, missing_val=None, extra_config=extra_config)
+    return convert_gbdt_common(operator, tree_infos, _get_tree_parameters, n_features, extra_config=extra_config)


 # Register the converters.
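The removed LightGBM branch read the missing-value routing out of the dumped tree: when a node's `missing_type` is `'None'`, LightGBM treats missing values unseen in training as zeros, so at inference the record goes left exactly when `0 < threshold`; otherwise `default_left` gives the direction. A sketch of that rule on a hand-made node dict (the field names follow the `Booster.dump_model()` structure used above; the values are illustrative):

```python
# Illustrative node in the shape produced by lightgbm's Booster.dump_model()["tree_info"].
node = {"missing_type": "None", "threshold": 0.5, "default_left": True}

def missing_goes_left(node):
    # Replicates the routing rule that the removed converter code implemented.
    if node["missing_type"] == "None":
        # Missing values not present in training data behave like zeros during inference.
        return 0 < node["threshold"]
    return node["default_left"]

print(missing_goes_left(node))  # True: 0 < 0.5, so a missing value follows the left child
```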
@@ -214,14 +214,7 @@ def convert_onnx_tree_ensemble_classifier(operator, device=None, extra_config={}
     )
     extra_config[constants.POST_TRANSFORM] = post_transform
     return convert_gbdt_classifier_common(
-        operator,
-        tree_infos,
-        _dummy_get_parameter,
-        n_features,
-        len(classes),
-        classes,
-        missing_val=None,
-        extra_config=extra_config,
+        operator, tree_infos, _dummy_get_parameter, n_features, len(classes), classes, extra_config
     )


@@ -243,9 +236,7 @@ def convert_onnx_tree_ensemble_regressor(operator, device=None, extra_config={})
     n_features, tree_infos, _, _ = _get_tree_infos_from_tree_ensemble(operator.raw_operator, device, extra_config)

     # Generate the model.
-    return convert_gbdt_common(
-        operator, tree_infos, _dummy_get_parameter, n_features, missing_val=None, extra_config=extra_config
-    )
+    return convert_gbdt_common(operator, tree_infos, _dummy_get_parameter, n_features, extra_config=extra_config)


 register_converter("ONNXMLTreeEnsembleClassifier", convert_onnx_tree_ensemble_classifier)
@@ -42,8 +42,8 @@ def convert_sklearn_random_forest_classifier(operator, device, extra_config):
     if not all(isinstance(c, int) for c in classes):
         raise RuntimeError("Random Forest Classifier translation only supports integer class labels")

-    def get_parameters_for_tree_trav(lefts, rights, features, thresholds, values, missings=None, extra_config={}):
-        return get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, missings, classes, extra_config)
+    def get_parameters_for_tree_trav(lefts, rights, features, thresholds, values, extra_config={}):
+        return get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, classes, extra_config)

     return convert_decision_ensemble_tree_common(
         operator,
@@ -77,8 +77,8 @@ def convert_sklearn_random_forest_regressor(operator, device, extra_config):
     # For Sklearn Trees we need to know how many trees are there for normalization.
     extra_config[constants.NUM_TREES] = len(tree_infos)

-    def get_parameters_for_tree_trav(lefts, rights, features, thresholds, values, missings=None, extra_config={}):
-        return get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, missings, None, extra_config)
+    def get_parameters_for_tree_trav(lefts, rights, features, thresholds, values, extra_config={}):
+        return get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, None, extra_config)

     return convert_decision_ensemble_tree_common(
         operator,
@@ -90,7 +90,7 @@ def convert_sklearn_gbdt_classifier(operator, device, extra_config):
         extra_config[constants.REORDER_TREES] = False

     return convert_gbdt_classifier_common(
-        operator, tree_infos, get_parameters_for_sklearn_common, n_features, n_classes, classes, missing_val=None, extra_config=extra_config
+        operator, tree_infos, get_parameters_for_sklearn_common, n_features, n_classes, classes, extra_config
     )


@@ -126,7 +126,7 @@ def convert_sklearn_gbdt_regressor(operator, device, extra_config):

     extra_config[constants.BASE_PREDICTION] = base_prediction

-    return convert_gbdt_common(operator, tree_infos, get_parameters_for_sklearn_common, n_features, missing_val=None, extra_config=extra_config)
+    return convert_gbdt_common(operator, tree_infos, get_parameters_for_sklearn_common, n_features, None, extra_config)


 def convert_sklearn_hist_gbdt_classifier(operator, device, extra_config):
@@ -167,7 +167,9 @@ def convert_sklearn_hist_gbdt_classifier(operator, device, extra_config):
     extra_config[constants.BASE_PREDICTION] = base_prediction
     extra_config[constants.REORDER_TREES] = False

-    return convert_gbdt_classifier_common(operator, tree_infos, _get_parameters_hist_gbdt, n_features, n_classes, classes, missing_val=None, extra_config=extra_config)
+    return convert_gbdt_classifier_common(
+        operator, tree_infos, _get_parameters_hist_gbdt, n_features, n_classes, classes, extra_config
+    )


 def convert_sklearn_hist_gbdt_regressor(operator, device, extra_config):
@@ -190,7 +192,7 @@ def convert_sklearn_hist_gbdt_regressor(operator, device, extra_config):
     n_features = operator.raw_operator.n_features_
     extra_config[constants.BASE_PREDICTION] = [[operator.raw_operator._baseline_prediction]]

-    return convert_gbdt_common(operator, tree_infos, _get_parameters_hist_gbdt, n_features, missing_val=None, extra_config=extra_config)
+    return convert_gbdt_common(operator, tree_infos, _get_parameters_hist_gbdt, n_features, None, extra_config)


 # Register the converters.
@@ -239,7 +239,7 @@ def convert_sklearn_isolation_forest(operator, device, extra_config):
     if tree_type == TreeImpl.gemm:
         net_parameters = [
             get_parameters_for_gemm_common(
-                tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, n_features, tree_param.missings
+                tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, n_features
             )
             for tree_param in tree_parameters
         ]
@@ -247,7 +247,7 @@ def convert_sklearn_isolation_forest(operator, device, extra_config):

         net_parameters = [
             get_parameters_for_tree_trav_sklearn(
-                tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, tree_param.missings, classes
+                tree_param.lefts, tree_param.rights, tree_param.features, tree_param.thresholds, tree_param.values, classes
             )
             for tree_param in tree_parameters
         ]
@@ -16,7 +16,7 @@ from ._gbdt_commons import convert_gbdt_classifier_common, convert_gbdt_common
 from ._tree_commons import TreeParameters


-def _tree_traversal(tree_info, lefts, rights, missing, features, thresholds, values):
+def _tree_traversal(tree_info, lefts, rights, features, thresholds, values):
     """
     Recursive function for parsing a tree and filling the input data structures.
     """
@@ -28,7 +28,6 @@ def _tree_traversal(tree_info, lefts, rights, missing, features, thresholds, val
             values.append([float(tree_info[count].split("=")[1])])
             lefts.append(-1)
             rights.append(-1)
-            missing.append(-1)
             count += 1
         else:
             features.append(int(tree_info[count].split(":")[1].split("<")[0].replace("[f", "")))
@@ -57,8 +56,6 @@ def _tree_traversal(tree_info, lefts, rights, missing, features, thresholds, val
                 r_correct_id += 1
             rights.append(r_correct_id)

-            missing_wrong_id = tree_info[count].split(",")[2].replace("missing=", "")
-            missing.append(l_correct_id if l_wrong_id == missing_wrong_id else r_correct_id)
             count += 1


@@ -68,15 +65,14 @@ def _get_tree_parameters(tree_info):
     """
     lefts = []
     rights = []
-    missing = []
     features = []
     thresholds = []
     values = []
     _tree_traversal(
-        tree_info.replace("[f", "").replace("[", "").replace("]", "").split(), lefts, rights, missing, features, thresholds, values
+        tree_info.replace("[f", "").replace("[", "").replace("]", "").split(), lefts, rights, features, thresholds, values
     )

-    return TreeParameters(lefts, rights, features, thresholds, values, missing)
+    return TreeParameters(lefts, rights, features, thresholds, values)


 def convert_sklearn_xgb_classifier(operator, device, extra_config):
@@ -101,9 +97,10 @@ def convert_sklearn_xgb_classifier(operator, device, extra_config):
     )
     tree_infos = operator.raw_operator.get_booster().get_dump()
     n_classes = operator.raw_operator.n_classes_
-    missing_val = operator.raw_operator.missing

-    return convert_gbdt_classifier_common(operator, tree_infos, _get_tree_parameters, n_features, n_classes, missing_val=missing_val, extra_config=extra_config)
+    return convert_gbdt_classifier_common(
+        operator, tree_infos, _get_tree_parameters, n_features, n_classes, extra_config=extra_config
+    )


 def convert_sklearn_xgb_regressor(operator, device, extra_config):
@@ -135,9 +132,8 @@ def convert_sklearn_xgb_regressor(operator, device, extra_config):
         base_prediction = [base_prediction]

     extra_config[constants.BASE_PREDICTION] = base_prediction
-    missing_val = operator.raw_operator.missing

-    return convert_gbdt_common(operator, tree_infos, _get_tree_parameters, n_features, missing_val=missing_val, extra_config=extra_config)
+    return convert_gbdt_common(operator, tree_infos, _get_tree_parameters, n_features, extra_config=extra_config)


 # Register the converters.
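The removed XGBoost lines parsed the `missing=` field of the textual dump produced by `Booster.get_dump()`, whose split lines look like `0:[f0<0.5] yes=1,no=2,missing=1`. A minimal sketch of pulling that field out of one such line (the line itself is illustrative):

```python
# One split line in the format emitted by xgboost's Booster.get_dump().
line = "0:[f0<0.5] yes=1,no=2,missing=1"

attrs = line.split()[1]                                  # "yes=1,no=2,missing=1"
missing_id = attrs.split(",")[2].replace("missing=", "")  # the same parse the removed code used
print(missing_id)  # "1" -> missing values are routed to node 1 (the "yes" child here)
```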
@@ -9,7 +9,6 @@ import numpy as np
 import hummingbird.ml
 from hummingbird.ml._utils import lightgbm_installed, onnx_runtime_installed, tvm_installed
 from tree_utils import gbdt_implementation_map
-from sklearn.datasets import make_classification, make_regression

 if lightgbm_installed():
     import lightgbm as lgb
@@ -259,43 +258,6 @@ class TestLGBMConverter(unittest.TestCase):
         self.assertIsNotNone(torch_model)
         np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)

-    # Missing values test.
-    @unittest.skipIf(not lightgbm_installed(), reason="LightGBM test requires LightGBM installed")
-    def test_run_lgbm_classifier_w_missing_vals_converter(self):
-        warnings.filterwarnings("ignore")
-        for extra_config_param in ["gemm", "tree_trav", "perf_tree_trav"]:
-            for missing in [None, np.nan]:
-                for model_class, n_classes in zip([lgb.LGBMClassifier, lgb.LGBMClassifier, lgb.LGBMRegressor], [2, 3, None]):
-                    model = model_class(use_missing=True, zero_as_missing=False)
-                    # Missing values during training + inference.
-                    if model_class == lgb.LGBMClassifier:
-                        X, y = make_classification(n_samples=100, n_features=3, n_informative=3, n_redundant=0, n_repeated=0, n_classes=n_classes, random_state=2021)
-                    else:
-                        X, y = make_regression(n_samples=100, n_features=3, n_informative=3, random_state=2021)
-                    X[:25][y[:25] == 0, 0] = np.nan if missing is None else missing
-                    model.fit(X, y)
-                    torch_model = hummingbird.ml.convert(model, "torch", X, extra_config={"tree_implementation": extra_config_param})
-                    self.assertIsNotNone(torch_model)
-                    if model_class == lgb.LGBMClassifier:
-                        np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
-                    else:
-                        np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)
-
-                    # Missing values during only inference.
-                    model = model_class(use_missing=True, zero_as_missing=False)
-                    if model_class == lgb.LGBMClassifier:
-                        X, y = make_classification(n_samples=100, n_features=3, n_informative=3, n_redundant=0, n_repeated=0, n_classes=n_classes, random_state=2021)
-                    else:
-                        X, y = make_regression(n_samples=100, n_features=3, n_informative=3, random_state=2021)
-                    model.fit(X, y)
-                    torch_model = hummingbird.ml.convert(model, "torch", X, extra_config={"tree_implementation": extra_config_param})
-                    X[:25][y[:25] == 0, 0] = np.nan if missing is None else missing
-                    self.assertIsNotNone(torch_model)
-                    if model_class == lgb.LGBMClassifier:
-                        np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
-                    else:
-                        np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)
-
     # Backend tests.
     # Test TorchScript backend regression.
     @unittest.skipIf(not lightgbm_installed(), reason="LightGBM test requires LightGBM installed")
@@ -10,7 +10,6 @@ import hummingbird.ml
 from hummingbird.ml._utils import xgboost_installed, tvm_installed
 from hummingbird.ml import constants
 from tree_utils import gbdt_implementation_map
-from sklearn.datasets import make_classification, make_regression

 if xgboost_installed():
     import xgboost as xgb
@@ -225,43 +224,6 @@ class TestXGBoostConverter(unittest.TestCase):
         self.assertIsNotNone(torch_model)
         np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)

-    # Missing values test.
-    @unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires XGBoost installed")
-    def test_run_xgb_classifier_w_missing_vals_converter(self):
-        warnings.filterwarnings("ignore")
-        for extra_config_param in ["gemm", "tree_trav", "perf_tree_trav"]:
-            for missing in [None, -99999, np.nan]:
-                for model_class, n_classes in zip([xgb.XGBClassifier, xgb.XGBClassifier, xgb.XGBRegressor], [2, 3, None]):
-                    model = model_class(missing=missing)
-                    # Missing values during both training and inference.
-                    if model_class == xgb.XGBClassifier:
-                        X, y = make_classification(n_samples=100, n_features=3, n_informative=3, n_redundant=0, n_repeated=0, n_classes=n_classes, random_state=2021)
-                    else:
-                        X, y = make_regression(n_samples=100, n_features=3, n_informative=3, random_state=2021)
-                    X[:25][y[:25] == 0, 0] = np.nan if missing is None else missing
-                    model.fit(X, y)
-                    torch_model = hummingbird.ml.convert(model, "torch", X, extra_config={"tree_implementation": extra_config_param})
-                    self.assertIsNotNone(torch_model)
-                    if model_class == xgb.XGBClassifier:
-                        np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
-                    else:
-                        np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)
-
-                    # Missing values during only inference.
-                    model = model_class(missing=missing)
-                    if model_class == xgb.XGBClassifier:
-                        X, y = make_classification(n_samples=100, n_features=3, n_informative=3, n_redundant=0, n_repeated=0, n_classes=n_classes, random_state=2021)
-                    else:
-                        X, y = make_regression(n_samples=100, n_features=3, n_informative=3, random_state=2021)
-                    model.fit(X, y)
-                    X[:25][y[:25] == 0, 0] = np.nan if missing is None else missing
-                    torch_model = hummingbird.ml.convert(model, "torch", X, extra_config={"tree_implementation": extra_config_param})
-                    self.assertIsNotNone(torch_model)
-                    if model_class == xgb.XGBClassifier:
-                        np.testing.assert_allclose(model.predict_proba(X), torch_model.predict_proba(X), rtol=1e-06, atol=1e-06)
-                    else:
-                        np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-06, atol=1e-06)
-
     # Torchscript backends.
     # Test TorchScript backend regression.
     @unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires XGBoost installed")