Fix several numerical precision issues in tree-based models (#511)

* wip
* fixing param order
* fix atol for regression chain
* more fixes
* fixing flake issues
* test data in float32
* test data in float32
* limiting tree depth for perf_tree_trav
* rtol 1e-4
* converting back to float32 after float64 tree operations
* better renaming of variables
* fix order
* wip
* explicitly specify dtype
* removing .predict test, increasing tolerance
* small refactoring
* fix missing self
* just the tensor without dtype is sufficient
* per-init one

Co-authored-by: Matteo Interlandi <mainterl@microsoft.com>
Co-authored-by: snakanda <snakanda@node.testvm.orion-pg0.wisc.cloudlab.us>
This commit is contained in:
Parent: 86f7674226
Commit: 0c9a71e32a
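
Example of what this change means for users, as a minimal sketch mirroring the updated tests in this commit (model and data below are illustrative): decision thresholds and leaf values can now be kept in float64 via the new TREE_OP_PRECISION_DTYPE option.

    import numpy as np
    from sklearn import datasets
    from sklearn.ensemble import RandomForestRegressor

    import hummingbird.ml
    from hummingbird.ml import constants

    X, y = datasets.make_regression(n_samples=100, n_features=10, random_state=0)
    X = X.astype("float32")  # inputs stay float32; tree thresholds/leaves go float64
    model = RandomForestRegressor(random_state=0).fit(X, y)

    torch_model = hummingbird.ml.convert(
        model, "torch", extra_config={constants.TREE_OP_PRECISION_DTYPE: "float64"}
    )
    np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-5, atol=1e-5)
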
@@ -25,7 +25,7 @@ from ._tree_implementations import GEMMGBDTImpl, TreeTraversalGBDTImpl, PerfectT
 
 def convert_gbdt_classifier_common(
-    operator, tree_infos, get_tree_parameters, n_features, n_classes, classes=None, extra_config={}
+    operator, tree_infos, get_tree_parameters, n_features, n_classes, classes=None, extra_config={}, decision_cond="<="
 ):
     """
     Common converter for GBDT classifiers.
@@ -37,6 +37,7 @@ def convert_gbdt_classifier_common(
         n_classes: How many classes are expected. 1 for regression tasks
         classes: The classes used for classification. None if implementing a regression model
         extra_config: Extra configuration used to properly implement the source tree
+        decision_cond: The condition of the decision nodes in the x <cond> threshold order. Default '<='. Values can be <=, <, >=, >
 
     Returns:
         A tree implementation in PyTorch
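
Example of what the new decision_cond parameter controls: a sample is routed to the left child when `x <cond> threshold` holds. Frameworks disagree on the comparison (scikit-learn and LightGBM split with `<=`, XGBoost with strict `<`; see the xgb converter hunks below), hence the knob. A sketch of the mapping this commit adds later in _tree_implementations:

    import torch

    decision_cond_map = {"<=": torch.le, "<": torch.lt, ">=": torch.ge, ">": torch.gt}
    cond = decision_cond_map["<"]  # XGBoost-style strict comparison
    x = torch.tensor([0.5, 1.0, 1.5])
    print(cond(x, 1.0))  # tensor([ True, False, False])
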
@@ -66,10 +67,14 @@ def convert_gbdt_classifier_common(
     if reorder_trees and n_classes > 1:
         tree_infos = [tree_infos[i * n_classes + j] for j in range(n_classes) for i in range(len(tree_infos) // n_classes)]
 
-    return convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes, extra_config)
+    return convert_gbdt_common(
+        operator, tree_infos, get_tree_parameters, n_features, classes, extra_config=extra_config, decision_cond=decision_cond
+    )
 
 
-def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes=None, extra_config={}):
+def convert_gbdt_common(
+    operator, tree_infos, get_tree_parameters, n_features, classes=None, extra_config={}, decision_cond="<="
+):
     """
     Common converter for GBDT models.
@@ -79,6 +84,7 @@ def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, c
         n_features: The number of features input to the model
         classes: The classes used for classification. None if implementing a regression model
         extra_config: Extra configuration used to properly implement the source tree
+        decision_cond: The condition of the decision nodes in the x <cond> threshold order. Default '<='. Values can be <=, <, >=, >
 
     Returns:
         A tree implementation in PyTorch
@@ -162,8 +168,14 @@ def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, c
 
     # Generate the tree implementation based on the selected strategy.
     if tree_type == TreeImpl.gemm:
-        return GEMMGBDTImpl(operator, net_parameters, n_features, classes, extra_config)
+        return GEMMGBDTImpl(
+            operator, net_parameters, n_features, classes, extra_config=extra_config, decision_cond=decision_cond
+        )
     if tree_type == TreeImpl.tree_trav:
-        return TreeTraversalGBDTImpl(operator, net_parameters, max_depth, n_features, classes, extra_config)
+        return TreeTraversalGBDTImpl(
+            operator, net_parameters, max_depth, n_features, classes, extra_config=extra_config, decision_cond=decision_cond
+        )
     else:  # Remaining possible case: tree_type == TreeImpl.perf_tree_trav.
-        return PerfectTreeTraversalGBDTImpl(operator, net_parameters, max_depth, n_features, classes, extra_config)
+        return PerfectTreeTraversalGBDTImpl(
+            operator, net_parameters, max_depth, n_features, classes, extra_config=extra_config, decision_cond=decision_cond
+        )
@@ -78,9 +78,12 @@ class ApplyBasePredictionPostTransform(PostTransform):
 
 
 class ApplySigmoidPostTransform(PostTransform):
+    def __init__(self):
+        self.one = torch.tensor(1.0)
+
     def __call__(self, x):
         output = torch.sigmoid(x)
-        return torch.cat([1 - output, output], dim=1)
+        return torch.cat([self.one - output, output], dim=1)
 
 
 class ApplySigmoidBasePredictionPostTransform(PostTransform):
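
Example of the refactored post-transform in isolation (a standalone sketch without the PostTransform base class; the rationale given here is an assumption): hoisting the constant 1.0 into a tensor built once at construction avoids re-materializing it on every call, and for a binary classifier the two columns are P(class 0) = 1 - sigmoid(x) and P(class 1) = sigmoid(x).

    import torch

    class ApplySigmoidPostTransform:
        def __init__(self):
            self.one = torch.tensor(1.0)

        def __call__(self, x):
            # Two-class probabilities from a single margin column.
            output = torch.sigmoid(x)
            return torch.cat([self.one - output, output], dim=1)

    print(ApplySigmoidPostTransform()(torch.tensor([[0.0], [2.0]])))
    # tensor([[0.5000, 0.5000],
    #         [0.1192, 0.8808]])
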
@@ -301,8 +304,8 @@ def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, val
     lefts = np.array(lefts)
     rights = np.array(rights)
     features = np.array(features)
-    thresholds = np.array(thresholds)
-    values = np.array(values)
+    thresholds = np.array(thresholds, dtype=np.float64)
+    values = np.array(values, dtype=np.float64)
 
     return [nodes_map, ids, lefts, rights, features, thresholds, values]
@@ -373,7 +376,7 @@ def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values,
         hidden_weights.append([1 if i == feature else 0 for i in range(n_features)])
         hidden_biases.append(thresh)
     weights.append(np.array(hidden_weights).astype("float32"))
-    biases.append(np.array(hidden_biases).astype("float32"))
+    biases.append(np.array(hidden_biases, dtype=np.float64))
 
     n_splits = len(hidden_weights)
@@ -431,7 +434,7 @@ def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values,
     biases.append(np.array(hidden_biases).astype("float32"))
 
     # OR neurons from the preceding layer in order to get final classes.
-    weights.append(np.transpose(np.array(class_proba).astype("float32")))
+    weights.append(np.transpose(np.array(class_proba).astype("float64")))
     biases.append(None)
 
     return weights, biases
@@ -456,7 +459,7 @@ def convert_decision_ensemble_tree_common(
             )
             for tree_param in tree_parameters
         ]
-        return GEMMDecisionTreeImpl(operator, net_parameters, n_features, classes)
+        return GEMMDecisionTreeImpl(operator, net_parameters, n_features, classes, extra_config=extra_config)
 
     net_parameters = [
         get_parameters_for_tree_trav(
@@ -467,4 +470,4 @@ def convert_decision_ensemble_tree_common(
     if tree_type == TreeImpl.tree_trav:
         return TreeTraversalDecisionTreeImpl(operator, net_parameters, max_depth, n_features, classes, extra_config)
     else:  # Remaining possible case: tree_type == TreeImpl.perf_tree_trav
-        return PerfectTreeTraversalDecisionTreeImpl(operator, net_parameters, max_depth, n_features, classes)
+        return PerfectTreeTraversalDecisionTreeImpl(operator, net_parameters, max_depth, n_features, classes, extra_config)
@@ -54,13 +54,16 @@ class AbstractPyTorchTreeImpl(AbstracTreeImpl, torch.nn.Module):
     Abstract class defining the basic structure for tree-based models implemented in PyTorch.
     """
 
-    def __init__(self, logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs):
+    def __init__(
+        self, logical_operator, tree_parameters, n_features, classes, n_classes, decision_cond="<=", extra_config={}, **kwargs
+    ):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
             n_features: The number of features input to the model
             classes: The classes used for classification. None if implementing a regression model
             n_classes: The total number of used classes
+            decision_cond: The condition of the decision nodes in the x <cond> threshold order. Default '<='. Values can be <=, <, >=, >
         """
         super(AbstractPyTorchTreeImpl, self).__init__(logical_operator, **kwargs)
@@ -84,6 +87,29 @@ class AbstractPyTorchTreeImpl(AbstracTreeImpl, torch.nn.Module):
             self.classes = torch.nn.Parameter(torch.IntTensor(classes), requires_grad=False)
             self.perform_class_select = True
 
+        # Set the decision condition.
+        decision_cond_map = {"<=": torch.le, "<": torch.lt, ">=": torch.ge, ">": torch.gt}
+        assert decision_cond in decision_cond_map.keys(), "decision_cond has to be one of: {}".format(
+            ",".join(decision_cond_map.keys())
+        )
+        self.decision_cond = decision_cond_map[decision_cond]
+
+        # In some cases float64 is required, otherwise we will lose precision.
+        tree_op_precision_dtype = None
+        if constants.TREE_OP_PRECISION_DTYPE in extra_config:
+            tree_op_precision_dtype = extra_config[constants.TREE_OP_PRECISION_DTYPE]
+            assert tree_op_precision_dtype in ["float32", "float64"], "{} has to be of type float32 or float64".format(
+                constants.TREE_OP_PRECISION_DTYPE
+            )
+        else:
+            tree_op_precision_dtype = "float32"
+        self.tree_op_precision_dtype = tree_op_precision_dtype
+
+        # We also register base_prediction here so that the tensor is moved to the proper hardware with the model,
+        # i.e., if cuda is selected, the parameter is automatically moved to the GPU.
+        if constants.BASE_PREDICTION in extra_config:
+            self.base_prediction = extra_config[constants.BASE_PREDICTION]
+
 
 class GEMMTreeImpl(AbstractPyTorchTreeImpl):
     """
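
A toy illustration of the precision issue behind the new tree_op_precision_dtype option (an assumption about the failure mode, consistent with the float64 changes in this commit): a float64 threshold that is not representable in float32 can flip a comparison after rounding, sending a sample down the wrong branch.

    import numpy as np

    t = np.float64(2 ** 24 + 1)  # 16777217 is not representable in float32
    x = np.float64(2 ** 24)      # 16777216

    print(x < t)                          # True  (float64 comparison)
    print(np.float32(x) < np.float32(t))  # False (t rounds down to 16777216.0)
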
@@ -100,7 +126,9 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl):
         """
         # If n_classes is not provided we induce it from tree parameters. Multioutput regression targets are also treated as separate classes.
         n_classes = n_classes if n_classes is not None else tree_parameters[0][0][2].shape[0]
-        super(GEMMTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs)
+        super(GEMMTreeImpl, self).__init__(
+            logical_operator, tree_parameters, n_features, classes, n_classes, extra_config=extra_config, **kwargs
+        )
 
         # Initialize the actual model.
         hidden_one_size = 0
@@ -113,10 +141,10 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl):
 
         n_trees = len(tree_parameters)
         weight_1 = np.zeros((n_trees, hidden_one_size, n_features))
-        bias_1 = np.zeros((n_trees, hidden_one_size))
+        bias_1 = np.zeros((n_trees, hidden_one_size), dtype=np.float64)
         weight_2 = np.zeros((n_trees, hidden_two_size, hidden_one_size))
         bias_2 = np.zeros((n_trees, hidden_two_size))
-        weight_3 = np.zeros((n_trees, hidden_three_size, hidden_two_size))
+        weight_3 = np.zeros((n_trees, hidden_three_size, hidden_two_size), dtype=np.float64)
 
         for i, (weight, bias) in enumerate(tree_parameters):
             if len(weight[0]) > 0:
@@ -133,24 +161,19 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl):
         self.hidden_three_size = hidden_three_size
 
         self.weight_1 = torch.nn.Parameter(torch.from_numpy(weight_1.reshape(-1, self.n_features).astype("float32")))
-        self.bias_1 = torch.nn.Parameter(torch.from_numpy(bias_1.reshape(-1, 1).astype("float32")))
+        self.bias_1 = torch.nn.Parameter(torch.from_numpy(bias_1.reshape(-1, 1).astype(self.tree_op_precision_dtype)))
 
         self.weight_2 = torch.nn.Parameter(torch.from_numpy(weight_2.astype("float32")))
         self.bias_2 = torch.nn.Parameter(torch.from_numpy(bias_2.reshape(-1, 1).astype("float32")))
 
-        self.weight_3 = torch.nn.Parameter(torch.from_numpy(weight_3.astype("float32")))
-
-        # We register also base_prediction here so that tensor will be moved to the proper hardware with the model.
-        # i.e., if cuda is selected, the parameter will be automatically moved on the GPU.
-        if constants.BASE_PREDICTION in extra_config:
-            self.base_prediction = extra_config[constants.BASE_PREDICTION]
+        self.weight_3 = torch.nn.Parameter(torch.from_numpy(weight_3.astype(self.tree_op_precision_dtype)))
 
     def aggregation(self, x):
         return x
 
     def forward(self, x):
         x = x.t()
-        x = torch.mm(self.weight_1, x) < self.bias_1
+        x = self.decision_cond(torch.mm(self.weight_1, x), self.bias_1)
         x = x.view(self.n_trees, self.hidden_one_size, -1)
         x = x.float()
@@ -158,7 +181,10 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl):
 
         x = x.view(self.n_trees * self.hidden_two_size, -1) == self.bias_2
         x = x.view(self.n_trees, self.hidden_two_size, -1)
-        x = x.float()
+        if self.tree_op_precision_dtype == "float32":
+            x = x.float()
+        else:
+            x = x.double()
 
         x = torch.matmul(self.weight_3, x)
         x = x.view(self.n_trees, self.hidden_three_size, -1)
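
The weight_3 matmul multiplies leaf-membership indicators by leaf values, so it is kept in the configured precision. A toy example of why accumulating many float32 leaf values can drift from the float64 result (random values, for illustration only):

    import numpy as np

    leaves = np.random.RandomState(0).uniform(-1.0, 1.0, size=1_000_000)
    print(leaves.astype(np.float32).sum(dtype=np.float32))  # float32 accumulation
    print(leaves.sum())                                     # float64 reference
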
@@ -203,7 +229,7 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
         # If n_classes is not provided we induce it from tree parameters. Multioutput regression targets are also treated as separate classes.
         n_classes = n_classes if n_classes is not None else tree_parameters[0][6].shape[1]
         super(TreeTraversalTreeImpl, self).__init__(
-            logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs
+            logical_operator, tree_parameters, n_features, classes, n_classes, extra_config=extra_config, **kwargs
         )
 
         # Initialize the actual model.
@@ -216,8 +242,8 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
         rights = np.zeros((self.num_trees, self.num_nodes), dtype=np.int64)
 
         features = np.zeros((self.num_trees, self.num_nodes), dtype=np.int64)
-        thresholds = np.zeros((self.num_trees, self.num_nodes), dtype=np.float32)
-        values = np.zeros((self.num_trees, self.num_nodes, self.n_classes), dtype=np.float32)
+        thresholds = np.zeros((self.num_trees, self.num_nodes), dtype=np.float64)
+        values = np.zeros((self.num_trees, self.num_nodes, self.n_classes), dtype=np.float64)
 
         for i in range(self.num_trees):
             lefts[i][: len(tree_parameters[i][0])] = tree_parameters[i][2]
@@ -230,17 +256,14 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
         self.rights = torch.nn.Parameter(torch.from_numpy(rights).view(-1), requires_grad=False)
 
         self.features = torch.nn.Parameter(torch.from_numpy(features).view(-1), requires_grad=False)
-        self.thresholds = torch.nn.Parameter(torch.from_numpy(thresholds).view(-1))
-        self.values = torch.nn.Parameter(torch.from_numpy(values).view(-1, self.n_classes))
+        self.thresholds = torch.nn.Parameter(torch.from_numpy(thresholds.astype(self.tree_op_precision_dtype)).view(-1))
+        self.values = torch.nn.Parameter(
+            torch.from_numpy(values.astype(self.tree_op_precision_dtype)).view(-1, self.n_classes)
+        )
 
         nodes_offset = [[i * self.num_nodes for i in range(self.num_trees)]]
         self.nodes_offset = torch.nn.Parameter(torch.LongTensor(nodes_offset), requires_grad=False)
 
-        # We register also base_prediction here so that tensor will be moved to the proper hardware with the model.
-        # i.e., if cuda is selected, the parameter will be automatically moved on the GPU.
-        if constants.BASE_PREDICTION in extra_config:
-            self.base_prediction = extra_config[constants.BASE_PREDICTION]
-
     def aggregation(self, x):
         return x
@@ -256,7 +279,7 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
         lefts = torch.index_select(self.lefts, 0, indexes).view(-1, self.num_trees)
         rights = torch.index_select(self.rights, 0, indexes).view(-1, self.num_trees)
 
-        indexes = torch.where(torch.ge(feature_values, thresholds), rights, lefts).long()
+        indexes = torch.where(self.decision_cond(feature_values, thresholds), lefts, rights).long()
        indexes = indexes + self.nodes_offset
         indexes = indexes.view(-1)
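
Note the behavioral fix in this hunk: the old code took the right child when torch.ge(feature, threshold) held, i.e. the left child only for feature strictly less than threshold, which disagrees with frameworks whose split condition is `feature <= threshold` exactly when a value lands on a threshold. Routing through decision_cond makes the equality case explicit, as this sketch shows:

    import torch

    f = torch.tensor([1.0])  # feature value exactly equal to the threshold
    t = torch.tensor([1.0])
    lefts, rights = torch.tensor([10]), torch.tensor([11])

    print(torch.where(torch.ge(f, t), rights, lefts))  # tensor([11]): old code goes right
    print(torch.where(torch.le(f, t), lefts, rights))  # tensor([10]): '<=' semantics go left
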
@@ -296,7 +319,7 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
         # If n_classes is not provided we induce it from tree parameters. Multioutput regression targets are also treated as separate classes.
         n_classes = n_classes if n_classes is not None else tree_parameters[0][6].shape[1]
         super(PerfectTreeTraversalTreeImpl, self).__init__(
-            logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs
+            logical_operator, tree_parameters, n_features, classes, n_classes, extra_config=extra_config, **kwargs
         )
 
         # Initialize the actual model.
@@ -307,7 +330,7 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
         node_maps = [tp[0] for tp in tree_parameters]
 
         weight_0 = np.zeros((self.num_trees, 2 ** max_depth - 1))
-        bias_0 = np.zeros((self.num_trees, 2 ** max_depth - 1))
+        bias_0 = np.zeros((self.num_trees, 2 ** max_depth - 1), dtype=np.float64)
         weight_1 = np.zeros((self.num_trees, 2 ** max_depth, self.n_classes))
 
         for i, node_map in enumerate(node_maps):
@@ -315,9 +338,10 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
 
         node_by_levels = [set() for _ in range(max_depth)]
         self._traverse_by_level(node_by_levels, 0, -1, max_depth)
 
         self.root_nodes = torch.nn.Parameter(torch.from_numpy(weight_0[:, 0].flatten().astype("int64")), requires_grad=False)
-        self.root_biases = torch.nn.Parameter(-1 * torch.from_numpy(bias_0[:, 0].astype("float32")), requires_grad=False)
+        self.root_biases = torch.nn.Parameter(
+            torch.from_numpy(bias_0[:, 0].astype(self.tree_op_precision_dtype)), requires_grad=False
+        )
 
         tree_indices = np.array([i for i in range(0, 2 * self.num_trees, 2)]).astype("int64")
         self.tree_indices = torch.nn.Parameter(torch.from_numpy(tree_indices), requires_grad=False)
@@ -329,7 +353,7 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
                 torch.from_numpy(weight_0[:, list(sorted(node_by_levels[i]))].flatten().astype("int64")), requires_grad=False
             )
             biases = torch.nn.Parameter(
-                torch.from_numpy(-1 * bias_0[:, list(sorted(node_by_levels[i]))].flatten().astype("float32")),
+                torch.from_numpy(bias_0[:, list(sorted(node_by_levels[i]))].flatten().astype(self.tree_op_precision_dtype)),
                 requires_grad=False,
             )
             self.nodes.append(nodes)
@@ -339,19 +363,14 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
         self.biases = torch.nn.ParameterList(self.biases)
 
         self.leaf_nodes = torch.nn.Parameter(
-            torch.from_numpy(weight_1.reshape((-1, self.n_classes)).astype("float32")), requires_grad=False
+            torch.from_numpy(weight_1.reshape((-1, self.n_classes)).astype(self.tree_op_precision_dtype)), requires_grad=False
         )
 
-        # We register also base_prediction here so that tensor will be moved to the proper hardware with the model.
-        # i.e., if cuda is selected, the parameter will be automatically moved on the GPU.
-        if constants.BASE_PREDICTION in extra_config:
-            self.base_prediction = extra_config[constants.BASE_PREDICTION]
-
     def aggregation(self, x):
         return x
 
     def forward(self, x):
-        prev_indices = (torch.ge(torch.index_select(x, 1, self.root_nodes), self.root_biases)).long()
+        prev_indices = (self.decision_cond(torch.index_select(x, 1, self.root_nodes), self.root_biases)).long()
         prev_indices = prev_indices + self.tree_indices
         prev_indices = prev_indices.view(-1)
@@ -359,7 +378,9 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
         for nodes, biases in zip(self.nodes, self.biases):
             gather_indices = torch.index_select(nodes, 0, prev_indices).view(-1, self.num_trees)
             features = torch.gather(x, 1, gather_indices).view(-1)
-            prev_indices = factor * prev_indices + torch.ge(features, torch.index_select(biases, 0, prev_indices)).long()
+            prev_indices = (
+                factor * prev_indices + self.decision_cond(features, torch.index_select(biases, 0, prev_indices)).long()
+            )
 
         output = torch.index_select(self.leaf_nodes, 0, prev_indices).view(-1, self.num_trees, self.n_classes)
@@ -390,20 +411,11 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
     def _get_weights_and_biases(self, nodes_map, tree_depth, weight_0, weight_1, bias_0):
         def depth_f_traversal(node, current_depth, node_id, leaf_start_id):
             weight_0[node_id] = node.feature
-            bias_0[node_id] = -node.threshold
+            bias_0[node_id] = node.threshold
             current_depth += 1
             node_id += 1
 
-            if node.left.feature == -1:
-                node_id += 2 ** (tree_depth - current_depth - 1) - 1
-                v = node.left.value
-                weight_1[leaf_start_id : leaf_start_id + 2 ** (tree_depth - current_depth - 1)] = (
-                    np.ones((2 ** (tree_depth - current_depth - 1), self.n_classes)) * v
-                )
-                leaf_start_id += 2 ** (tree_depth - current_depth - 1)
-            else:
-                node_id, leaf_start_id = depth_f_traversal(node.left, current_depth, node_id, leaf_start_id)
-
+            # Condition false (right sub-tree)
             if node.right.feature == -1:
                 node_id += 2 ** (tree_depth - current_depth - 1) - 1
                 v = node.right.value
@@ -414,6 +426,17 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
             else:
                 node_id, leaf_start_id = depth_f_traversal(node.right, current_depth, node_id, leaf_start_id)
 
+            # Condition true (left sub-tree)
+            if node.left.feature == -1:
+                node_id += 2 ** (tree_depth - current_depth - 1) - 1
+                v = node.left.value
+                weight_1[leaf_start_id : leaf_start_id + 2 ** (tree_depth - current_depth - 1)] = (
+                    np.ones((2 ** (tree_depth - current_depth - 1), self.n_classes)) * v
+                )
+                leaf_start_id += 2 ** (tree_depth - current_depth - 1)
+            else:
+                node_id, leaf_start_id = depth_f_traversal(node.left, current_depth, node_id, leaf_start_id)
+
             return node_id, leaf_start_id
 
         depth_f_traversal(nodes_map[0], -1, 0, 0)
@@ -426,14 +449,17 @@ class GEMMDecisionTreeImpl(GEMMTreeImpl):
     """
 
-    def __init__(self, logical_operator, tree_parameters, n_features, classes=None):
+    def __init__(self, logical_operator, tree_parameters, n_features, classes=None, extra_config={}):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
             n_features: The number of features input to the model
             classes: The classes used for classification. None if implementing a regression model
+            extra_config: Extra configuration used to properly implement the source tree
         """
-        super(GEMMDecisionTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes)
+        super(GEMMDecisionTreeImpl, self).__init__(
+            logical_operator, tree_parameters, n_features, classes, extra_config=extra_config
+        )
 
     def aggregation(self, x):
         output = x.sum(0).t()
@@ -446,7 +472,7 @@ class TreeTraversalDecisionTreeImpl(TreeTraversalTreeImpl):
     Class implementing the Tree Traversal strategy in PyTorch for decision tree models.
     """
 
-    def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}):
+    def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}, **kwargs):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
|
@ -456,7 +482,7 @@ class TreeTraversalDecisionTreeImpl(TreeTraversalTreeImpl):
|
|||
extra_config: Extra configuration used to properly implement the source tree
|
||||
"""
|
||||
super(TreeTraversalDecisionTreeImpl, self).__init__(
|
||||
logical_operator, tree_parameters, max_depth, n_features, classes, extra_config=extra_config
|
||||
logical_operator, tree_parameters, max_depth, n_features, classes, extra_config=extra_config, **kwargs
|
||||
)
|
||||
|
||||
def aggregation(self, x):
|
||||
|
@@ -470,16 +496,17 @@ class PerfectTreeTraversalDecisionTreeImpl(PerfectTreeTraversalTreeImpl):
     Class implementing the Perfect Tree Traversal strategy in PyTorch for decision tree models.
     """
 
-    def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None):
+    def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}, **kwargs):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
             max_depth: The maximum tree-depth in the model
            n_features: The number of features input to the model
             classes: The classes used for classification. None if implementing a regression model
+            extra_config: Extra configuration used to properly implement the source tree
         """
         super(PerfectTreeTraversalDecisionTreeImpl, self).__init__(
-            logical_operator, tree_parameters, max_depth, n_features, classes
+            logical_operator, tree_parameters, max_depth, n_features, classes, extra_config=extra_config, **kwargs
         )
 
     def aggregation(self, x):
@@ -494,7 +521,7 @@ class GEMMGBDTImpl(GEMMTreeImpl):
     Class implementing the GEMM strategy (in PyTorch) for GBDT models.
     """
 
-    def __init__(self, logical_operator, tree_parameters, n_features, classes=None, extra_config={}):
+    def __init__(self, logical_operator, tree_parameters, n_features, classes=None, extra_config={}, **kwargs):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
@@ -502,7 +529,7 @@ class GEMMGBDTImpl(GEMMTreeImpl):
             classes: The classes used for classification. None if implementing a regression model
             extra_config: Extra configuration used to properly implement the source tree
         """
-        super(GEMMGBDTImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, 1, extra_config)
+        super(GEMMGBDTImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, 1, extra_config, **kwargs)
 
         self.n_gbdt_classes = 1
         self.post_transform = _tree_commons.PostTransform()
@@ -526,7 +553,7 @@ class TreeTraversalGBDTImpl(TreeTraversalTreeImpl):
     Class implementing the Tree Traversal strategy in PyTorch.
     """
 
-    def __init__(self, logical_operator, tree_parameters, max_detph, n_features, classes=None, extra_config={}):
+    def __init__(self, logical_operator, tree_parameters, max_detph, n_features, classes=None, extra_config={}, **kwargs):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
@@ -536,7 +563,7 @@ class TreeTraversalGBDTImpl(TreeTraversalTreeImpl):
             extra_config: Extra configuration used to properly implement the source tree
         """
         super(TreeTraversalGBDTImpl, self).__init__(
-            logical_operator, tree_parameters, max_detph, n_features, classes, 1, extra_config
+            logical_operator, tree_parameters, max_detph, n_features, classes, 1, extra_config, **kwargs
         )
 
         self.n_gbdt_classes = 1
@@ -561,7 +588,7 @@ class PerfectTreeTraversalGBDTImpl(PerfectTreeTraversalTreeImpl):
     Class implementing the Perfect Tree Traversal strategy in PyTorch.
     """
 
-    def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}):
+    def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}, **kwargs):
         """
         Args:
             tree_parameters: The parameters defining the tree structure
@@ -571,7 +598,7 @@ class PerfectTreeTraversalGBDTImpl(PerfectTreeTraversalTreeImpl):
             extra_config: Extra configuration used to properly implement the source tree
         """
         super(PerfectTreeTraversalGBDTImpl, self).__init__(
-            logical_operator, tree_parameters, max_depth, n_features, classes, 1, extra_config
+            logical_operator, tree_parameters, max_depth, n_features, classes, 1, extra_config, **kwargs
         )
 
         self.n_gbdt_classes = 1
|
@ -99,7 +99,7 @@ def convert_sklearn_xgb_classifier(operator, device, extra_config):
|
|||
n_classes = operator.raw_operator.n_classes_
|
||||
|
||||
return convert_gbdt_classifier_common(
|
||||
operator, tree_infos, _get_tree_parameters, n_features, n_classes, extra_config=extra_config
|
||||
operator, tree_infos, _get_tree_parameters, n_features, n_classes, decision_cond="<", extra_config=extra_config
|
||||
)
|
||||
|
||||
|
||||
|
@@ -133,7 +133,9 @@ def convert_sklearn_xgb_regressor(operator, device, extra_config):
 
     extra_config[constants.BASE_PREDICTION] = base_prediction
 
-    return convert_gbdt_common(operator, tree_infos, _get_tree_parameters, n_features, extra_config=extra_config)
+    return convert_gbdt_common(
+        operator, tree_infos, _get_tree_parameters, n_features, decision_cond="<", extra_config=extra_config
+    )
 
 
 # Register the converters.
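
No user action is needed to pick up the strict comparison: the XGBoost converters now pass decision_cond="<" into the common GBDT path themselves. A hypothetical end-to-end sketch (the XGBRegressor settings are illustrative):

    import xgboost as xgb
    from sklearn import datasets
    import hummingbird.ml

    X, y = datasets.make_regression(n_samples=100, n_features=10, random_state=0)
    model = xgb.XGBRegressor(n_estimators=10, random_state=0).fit(X, y)

    # The converter applies XGBoost's 'feature < threshold' rule internally.
    torch_model = hummingbird.ml.convert(model, "torch")
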
@@ -469,6 +469,9 @@ backends = _build_backend_map()
 TREE_IMPLEMENTATION = "tree_implementation"
 """Which tree implementation to use. Values can be: gemm, tree_trav, perf_tree_trav."""
 
+TREE_OP_PRECISION_DTYPE = "tree_op_precision_dtype"
+"""Which data type to use for the threshold and leaf values of decision nodes. Values can be: float32 or float64."""
+
 ONNX_OUTPUT_MODEL_NAME = "onnx_model_name"
 """For ONNX models we can set the name of the output model."""
@@ -7,6 +7,7 @@ import warnings
 import numpy as np
 
 import hummingbird.ml
+from hummingbird.ml import constants
 from hummingbird.ml._utils import lightgbm_installed, onnx_runtime_installed, tvm_installed
 from tree_utils import gbdt_implementation_map
@@ -362,11 +363,15 @@ class TestLGBMConverter(unittest.TestCase):
             model.fit(X, y)
 
             # Create TVM model.
-            tvm_model = hummingbird.ml.convert(model, "tvm", X, extra_config={"tree_implementation": tree_implementation})
+            tvm_model = hummingbird.ml.convert(
+                model,
+                "tvm",
+                X,
+                extra_config={constants.TREE_IMPLEMENTATION: tree_implementation, constants.TREE_OP_PRECISION_DTYPE: "float64"}
+            )
 
             # Check results.
             np.testing.assert_allclose(tvm_model.predict(X), model.predict(X))
-            np.testing.assert_allclose(tvm_model.predict_proba(X), model.predict_proba(X), rtol=1e-06, atol=1e-06)
+            np.testing.assert_allclose(tvm_model.predict_proba(X), model.predict_proba(X), rtol=1e-04, atol=1e-04)
 
 
 if __name__ == "__main__":
@@ -15,6 +15,8 @@ from hummingbird.ml._utils import tvm_installed
 from hummingbird.ml import constants
 from tree_utils import dt_implementation_map
 
+import random
+
 
 class TestSklearnTreeConverter(unittest.TestCase):
     # Check tree implementation
@@ -725,17 +727,25 @@ class TestSklearnTreeConverter(unittest.TestCase):
         for tree_method in ["gemm", "tree_trav", "perf_tree_trav"]:
             for n_targets in [1, 2, 7]:
                 for tree_class in [DecisionTreeRegressor, ExtraTreesRegressor, RandomForestRegressor]:
-                    model = tree_class()
+                    seed = random.randint(0, 2**32 - 1)
+                    if tree_method == "perf_tree_trav":
+                        model = tree_class(random_state=seed, max_depth=10)
+                    else:
+                        model = tree_class(random_state=seed)
                     X, y = datasets.make_regression(
-                        n_samples=100, n_features=10, n_informative=5, n_targets=n_targets, random_state=2021
+                        n_samples=100, n_features=10, n_informative=5, n_targets=n_targets, random_state=seed
                     )
                     model.fit(X, y)
+                    X = X.astype('float32')
+                    y = y.astype('float32')
 
                     torch_model = hummingbird.ml.convert(
-                        model, "torch", extra_config={constants.TREE_IMPLEMENTATION: tree_method}
+                        model,
+                        "torch",
+                        extra_config={constants.TREE_IMPLEMENTATION: tree_method, constants.TREE_OP_PRECISION_DTYPE: "float64"}
                     )
                     self.assertTrue(torch_model is not None)
-                    np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-5, atol=1e-5)
+                    np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-5, atol=1e-5, err_msg="{}/{}/{}/{}".format(tree_method, n_targets, tree_class, seed))
 
 
 if __name__ == "__main__":
@@ -17,46 +17,53 @@ from sklearn.svm import LinearSVR
 from sklearn.multioutput import MultiOutputRegressor, RegressorChain
 
 import hummingbird.ml
+from hummingbird.ml import constants
 
 import random
 
+random.seed(2021)
+
 
 class TestSklearnMultioutputRegressor(unittest.TestCase):
     # Test MultiOutputRegressor with different child learners
     def test_sklearn_multioutput_regressor(self):
         for n_targets in [2, 3, 4]:
             for model_class in [DecisionTreeRegressor, ExtraTreesRegressor, RandomForestRegressor, LinearRegression]:
-                model = MultiOutputRegressor(model_class())
+                seed = random.randint(0, 2**32 - 1)
+                if model_class != LinearRegression:
+                    model = MultiOutputRegressor(model_class(random_state=seed))
+                else:
+                    model = MultiOutputRegressor(model_class())
                 X, y = datasets.make_regression(
-                    n_samples=50, n_features=10, n_informative=5, n_targets=n_targets, random_state=2020
+                    n_samples=50, n_features=10, n_informative=5, n_targets=n_targets, random_state=seed
                 )
+                X = X.astype("float32")
+                y = y.astype("float32")
                 model.fit(X, y)
 
-                torch_model = hummingbird.ml.convert(model, "torch")
+                torch_model = hummingbird.ml.convert(model, "torch", extra_config={constants.TREE_OP_PRECISION_DTYPE: "float64"})
                 self.assertTrue(torch_model is not None)
-                np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-5, atol=1e-4)
+                np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-5, atol=1e-4, err_msg="{}/{}/{}".format(n_targets, model_class, seed))
 
     # Test RegressorChain with different child learners
     def test_sklearn_regressor_chain(self):
         for n_targets in [2, 3, 4]:
             for model_class in [DecisionTreeRegressor, ExtraTreesRegressor, RandomForestRegressor, LinearRegression]:
+                seed = random.randint(0, 2**32 - 1)
                 order = [i for i in range(n_targets)]
-                random.shuffle(order)
-                model = RegressorChain(model_class(), order=order)
+                random.Random(seed).shuffle(order)
+                if model_class != LinearRegression:
+                    model = RegressorChain(model_class(random_state=seed), order=order)
+                else:
+                    model = RegressorChain(model_class(), order=order)
                 X, y = datasets.make_regression(
-                    n_samples=50, n_features=10, n_informative=5, n_targets=n_targets, random_state=2021
+                    n_samples=50, n_features=10, n_informative=5, n_targets=n_targets, random_state=seed
                 )
+                X = X.astype("float32")
+                y = y.astype("float32")
                 model.fit(X, y)
 
-                torch_model = hummingbird.ml.convert(model, "torch")
+                torch_model = hummingbird.ml.convert(model, "torch", extra_config={constants.TREE_OP_PRECISION_DTYPE: "float64"})
                 self.assertTrue(torch_model is not None)
-                np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-5, atol=1e-5)
+                np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-4, atol=1e-4, err_msg="{}/{}/{}".format(n_targets, model_class, seed))
 
 
 if __name__ == "__main__":