Fix several numerical precision issues in tree-based models (#511)

* wip

* fixing param order

* fix atol for regression chain

* more fixes

* fixing flake issues

* test data in float32

* test data in float32

* limiting tree depth for perf_tree_trav

* rtol 1e-4

* converting back to float32 after float64 tree operations

* better renaming of variables

* fix order

* wip

* explicitly specify dtype

* removing .predict test. increasing tolerance

* small refactoring

* fix missing self

* just the tensor without dtype is sufficient

* per-init one

Co-authored-by: Matteo Interlandi <mainterl@microsoft.com>
Co-authored-by: snakanda <snakanda@node.testvm.orion-pg0.wisc.cloudlab.us>
This commit is contained in:
Supun Nakandala 2021-06-15 13:23:43 -07:00 committed by GitHub
Parent 86f7674226
Commit 0c9a71e32a
No key matching this signature was found
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 162 additions and 93 deletions
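
The precision issue in a nutshell: tree thresholds are learned in float64, and rounding them (or the inputs) to float32 can flip a split decision, sending a sample to a different leaf. A minimal sketch with illustrative values (not taken from the commit):

import numpy as np

x64, t64 = np.float64(1.0000000001), np.float64(1.0)
x32, t32 = np.float32(x64), np.float32(t64)  # 1.0000000001 rounds to exactly 1.0

print(x64 <= t64)  # False: in float64 the sample goes right
print(x32 <= t32)  # True:  in float32 it goes left, to a different leaf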

View file

@@ -25,7 +25,7 @@ from ._tree_implementations import GEMMGBDTImpl, TreeTraversalGBDTImpl, PerfectT
def convert_gbdt_classifier_common(
operator, tree_infos, get_tree_parameters, n_features, n_classes, classes=None, extra_config={}
operator, tree_infos, get_tree_parameters, n_features, n_classes, classes=None, extra_config={}, decision_cond="<="
):
"""
Common converter for GBDT classifiers.
@@ -37,6 +37,7 @@ def convert_gbdt_classifier_common(
n_classes: How many classes are expected. 1 for regression tasks
classes: The classes used for classification. None if implementing a regression model
extra_config: Extra configuration used to properly implement the source tree
decision_cond: The condition used at decision nodes, evaluated as x <cond> threshold. Default '<='. Values can be <=, <, >=, >
Returns:
A tree implementation in PyTorch
@@ -66,10 +67,14 @@ def convert_gbdt_classifier_common(
if reorder_trees and n_classes > 1:
tree_infos = [tree_infos[i * n_classes + j] for j in range(n_classes) for i in range(len(tree_infos) // n_classes)]
return convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes, extra_config)
return convert_gbdt_common(
operator, tree_infos, get_tree_parameters, n_features, classes, extra_config=extra_config, decision_cond=decision_cond
)
def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes=None, extra_config={}):
def convert_gbdt_common(
operator, tree_infos, get_tree_parameters, n_features, classes=None, extra_config={}, decision_cond="<="
):
"""
Common converter for GBDT models.
@@ -79,6 +84,7 @@ def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, c
n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model
extra_config: Extra configuration used to properly implement the source tree
decision_cond: The condition used at decision nodes, evaluated as x <cond> threshold. Default '<='. Values can be <=, <, >=, >
Returns:
A tree implementation in PyTorch
@@ -162,8 +168,14 @@ def convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, c
# Generate the tree implementation based on the selected strategy.
if tree_type == TreeImpl.gemm:
return GEMMGBDTImpl(operator, net_parameters, n_features, classes, extra_config)
return GEMMGBDTImpl(
operator, net_parameters, n_features, classes, extra_config=extra_config, decision_cond=decision_cond
)
if tree_type == TreeImpl.tree_trav:
return TreeTraversalGBDTImpl(operator, net_parameters, max_depth, n_features, classes, extra_config)
return TreeTraversalGBDTImpl(
operator, net_parameters, max_depth, n_features, classes, extra_config=extra_config, decision_cond=decision_cond
)
else: # Remaining possible case: tree_type == TreeImpl.perf_tree_trav.
return PerfectTreeTraversalGBDTImpl(operator, net_parameters, max_depth, n_features, classes, extra_config)
return PerfectTreeTraversalGBDTImpl(
operator, net_parameters, max_depth, n_features, classes, extra_config=extra_config, decision_cond=decision_cond
)
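
For context, the reorder_trees branch above regroups a per-boosting-round tree list into a per-class list before conversion. A worked example with hypothetical tree labels (r<round>_c<class>):

n_classes = 3
tree_infos = ["r0_c0", "r0_c1", "r0_c2", "r1_c0", "r1_c1", "r1_c2"]
reordered = [tree_infos[i * n_classes + j]
             for j in range(n_classes)
             for i in range(len(tree_infos) // n_classes)]
print(reordered)  # ['r0_c0', 'r1_c0', 'r0_c1', 'r1_c1', 'r0_c2', 'r1_c2']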

View file

@@ -78,9 +78,12 @@ class ApplyBasePredictionPostTransform(PostTransform):
class ApplySigmoidPostTransform(PostTransform):
def __init__(self):
self.one = torch.tensor(1.0)
def __call__(self, x):
output = torch.sigmoid(x)
return torch.cat([1 - output, output], dim=1)
return torch.cat([self.one - output, output], dim=1)
class ApplySigmoidBasePredictionPostTransform(PostTransform):
@@ -301,8 +304,8 @@ def get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, val
lefts = np.array(lefts)
rights = np.array(rights)
features = np.array(features)
thresholds = np.array(thresholds)
values = np.array(values)
thresholds = np.array(thresholds, dtype=np.float64)
values = np.array(values, dtype=np.float64)
return [nodes_map, ids, lefts, rights, features, thresholds, values]
@@ -373,7 +376,7 @@ def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values,
hidden_weights.append([1 if i == feature else 0 for i in range(n_features)])
hidden_biases.append(thresh)
weights.append(np.array(hidden_weights).astype("float32"))
biases.append(np.array(hidden_biases).astype("float32"))
biases.append(np.array(hidden_biases, dtype=np.float64))
n_splits = len(hidden_weights)
@@ -431,7 +434,7 @@ def get_parameters_for_gemm_common(lefts, rights, features, thresholds, values,
biases.append(np.array(hidden_biases).astype("float32"))
# OR neurons from the preceding layer in order to get final classes.
weights.append(np.transpose(np.array(class_proba).astype("float32")))
weights.append(np.transpose(np.array(class_proba).astype("float64")))
biases.append(None)
return weights, biases
@@ -456,7 +459,7 @@ def convert_decision_ensemble_tree_common(
)
for tree_param in tree_parameters
]
return GEMMDecisionTreeImpl(operator, net_parameters, n_features, classes)
return GEMMDecisionTreeImpl(operator, net_parameters, n_features, classes, extra_config=extra_config)
net_parameters = [
get_parameters_for_tree_trav(
@@ -467,4 +470,4 @@ def convert_decision_ensemble_tree_common(
if tree_type == TreeImpl.tree_trav:
return TreeTraversalDecisionTreeImpl(operator, net_parameters, max_depth, n_features, classes, extra_config)
else: # Remaining possible case: tree_type == TreeImpl.perf_tree_trav
return PerfectTreeTraversalDecisionTreeImpl(operator, net_parameters, max_depth, n_features, classes)
return PerfectTreeTraversalDecisionTreeImpl(operator, net_parameters, max_depth, n_features, classes, extra_config)
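
The ApplySigmoidPostTransform change above swaps the Python scalar 1 for self.one = torch.tensor(1.0), presumably so the constant is a real tensor for tracing/compilation backends; the arithmetic is unchanged. A quick sketch of what the transform computes for a binary classifier:

import torch

margins = torch.tensor([[0.0], [2.0]])  # raw scores, one column
one = torch.tensor(1.0)
p = torch.sigmoid(margins)
probs = torch.cat([one - p, p], dim=1)
print(probs)  # [[0.5000, 0.5000], [0.1192, 0.8808]]; each row sums to 1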

View file

@@ -54,13 +54,16 @@ class AbstractPyTorchTreeImpl(AbstracTreeImpl, torch.nn.Module):
Abstract class defining the basic structure for tree-based models implemented in PyTorch.
"""
def __init__(self, logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs):
def __init__(
self, logical_operator, tree_parameters, n_features, classes, n_classes, decision_cond="<=", extra_config={}, **kwargs
):
"""
Args:
tree_parameters: The parameters defining the tree structure
n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model
n_classes: The total number of used classes
decision_cond: The condition used at decision nodes, evaluated as x <cond> threshold. Default '<='. Values can be <=, <, >=, >
"""
super(AbstractPyTorchTreeImpl, self).__init__(logical_operator, **kwargs)
@@ -84,6 +87,29 @@ class AbstractPyTorchTreeImpl(AbstracTreeImpl, torch.nn.Module):
self.classes = torch.nn.Parameter(torch.IntTensor(classes), requires_grad=False)
self.perform_class_select = True
# Set the decision condition.
decision_cond_map = {"<=": torch.le, "<": torch.lt, ">=": torch.ge, ">": torch.gt}
assert decision_cond in decision_cond_map.keys(), "decision_cond has to be one of:{}".format(
",".join(decision_cond_map.keys())
)
self.decision_cond = decision_cond_map[decision_cond]
# In some cases float64 is required, otherwise we will lose precision.
tree_op_precision_dtype = None
if constants.TREE_OP_PRECISION_DTYPE in extra_config:
tree_op_precision_dtype = extra_config[constants.TREE_OP_PRECISION_DTYPE]
assert tree_op_precision_dtype in ["float32", "float64"], "{} has to be of type float32 or float64".format(
constants.TREE_OP_PRECISION_DTYPE
)
else:
tree_op_precision_dtype = "float32"
self.tree_op_precision_dtype = tree_op_precision_dtype
# We also register base_prediction here so that the tensor will be moved to the proper hardware with the model.
# i.e., if cuda is selected, the parameter will be automatically moved to the GPU.
if constants.BASE_PREDICTION in extra_config:
self.base_prediction = extra_config[constants.BASE_PREDICTION]
class GEMMTreeImpl(AbstractPyTorchTreeImpl):
"""
@@ -100,7 +126,9 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl):
"""
# If n_classes is not provided we induce it from tree parameters. Multioutput regression targets are also treated as separate classes.
n_classes = n_classes if n_classes is not None else tree_parameters[0][0][2].shape[0]
super(GEMMTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs)
super(GEMMTreeImpl, self).__init__(
logical_operator, tree_parameters, n_features, classes, n_classes, extra_config=extra_config, **kwargs
)
# Initialize the actual model.
hidden_one_size = 0
@@ -113,10 +141,10 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl):
n_trees = len(tree_parameters)
weight_1 = np.zeros((n_trees, hidden_one_size, n_features))
bias_1 = np.zeros((n_trees, hidden_one_size))
bias_1 = np.zeros((n_trees, hidden_one_size), dtype=np.float64)
weight_2 = np.zeros((n_trees, hidden_two_size, hidden_one_size))
bias_2 = np.zeros((n_trees, hidden_two_size))
weight_3 = np.zeros((n_trees, hidden_three_size, hidden_two_size))
weight_3 = np.zeros((n_trees, hidden_three_size, hidden_two_size), dtype=np.float64)
for i, (weight, bias) in enumerate(tree_parameters):
if len(weight[0]) > 0:
@@ -133,24 +161,19 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl):
self.hidden_three_size = hidden_three_size
self.weight_1 = torch.nn.Parameter(torch.from_numpy(weight_1.reshape(-1, self.n_features).astype("float32")))
self.bias_1 = torch.nn.Parameter(torch.from_numpy(bias_1.reshape(-1, 1).astype("float32")))
self.bias_1 = torch.nn.Parameter(torch.from_numpy(bias_1.reshape(-1, 1).astype(self.tree_op_precision_dtype)))
self.weight_2 = torch.nn.Parameter(torch.from_numpy(weight_2.astype("float32")))
self.bias_2 = torch.nn.Parameter(torch.from_numpy(bias_2.reshape(-1, 1).astype("float32")))
self.weight_3 = torch.nn.Parameter(torch.from_numpy(weight_3.astype("float32")))
# We register also base_prediction here so that tensor will be moved to the proper hardware with the model.
# i.e., if cuda is selected, the parameter will be automatically moved on the GPU.
if constants.BASE_PREDICTION in extra_config:
self.base_prediction = extra_config[constants.BASE_PREDICTION]
self.weight_3 = torch.nn.Parameter(torch.from_numpy(weight_3.astype(self.tree_op_precision_dtype)))
def aggregation(self, x):
return x
def forward(self, x):
x = x.t()
x = torch.mm(self.weight_1, x) < self.bias_1
x = self.decision_cond(torch.mm(self.weight_1, x), self.bias_1)
x = x.view(self.n_trees, self.hidden_one_size, -1)
x = x.float()
@@ -158,7 +181,10 @@ class GEMMTreeImpl(AbstractPyTorchTreeImpl):
x = x.view(self.n_trees * self.hidden_two_size, -1) == self.bias_2
x = x.view(self.n_trees, self.hidden_two_size, -1)
x = x.float()
if self.tree_op_precision_dtype == "float32":
x = x.float()
else:
x = x.double()
x = torch.matmul(self.weight_3, x)
x = x.view(self.n_trees, self.hidden_three_size, -1)
@@ -203,7 +229,7 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
# If n_classes is not provided we induce it from tree parameters. Multioutput regression targets are also treated as separate classes.
n_classes = n_classes if n_classes is not None else tree_parameters[0][6].shape[1]
super(TreeTraversalTreeImpl, self).__init__(
logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs
logical_operator, tree_parameters, n_features, classes, n_classes, extra_config=extra_config, **kwargs
)
# Initialize the actual model.
@@ -216,8 +242,8 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
rights = np.zeros((self.num_trees, self.num_nodes), dtype=np.int64)
features = np.zeros((self.num_trees, self.num_nodes), dtype=np.int64)
thresholds = np.zeros((self.num_trees, self.num_nodes), dtype=np.float32)
values = np.zeros((self.num_trees, self.num_nodes, self.n_classes), dtype=np.float32)
thresholds = np.zeros((self.num_trees, self.num_nodes), dtype=np.float64)
values = np.zeros((self.num_trees, self.num_nodes, self.n_classes), dtype=np.float64)
for i in range(self.num_trees):
lefts[i][: len(tree_parameters[i][0])] = tree_parameters[i][2]
@@ -230,17 +256,14 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
self.rights = torch.nn.Parameter(torch.from_numpy(rights).view(-1), requires_grad=False)
self.features = torch.nn.Parameter(torch.from_numpy(features).view(-1), requires_grad=False)
self.thresholds = torch.nn.Parameter(torch.from_numpy(thresholds).view(-1))
self.values = torch.nn.Parameter(torch.from_numpy(values).view(-1, self.n_classes))
self.thresholds = torch.nn.Parameter(torch.from_numpy(thresholds.astype(self.tree_op_precision_dtype)).view(-1))
self.values = torch.nn.Parameter(
torch.from_numpy(values.astype(self.tree_op_precision_dtype)).view(-1, self.n_classes)
)
nodes_offset = [[i * self.num_nodes for i in range(self.num_trees)]]
self.nodes_offset = torch.nn.Parameter(torch.LongTensor(nodes_offset), requires_grad=False)
# We register also base_prediction here so that tensor will be moved to the proper hardware with the model.
# i.e., if cuda is selected, the parameter will be automatically moved on the GPU.
if constants.BASE_PREDICTION in extra_config:
self.base_prediction = extra_config[constants.BASE_PREDICTION]
def aggregation(self, x):
return x
@@ -256,7 +279,7 @@ class TreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
lefts = torch.index_select(self.lefts, 0, indexes).view(-1, self.num_trees)
rights = torch.index_select(self.rights, 0, indexes).view(-1, self.num_trees)
indexes = torch.where(torch.ge(feature_values, thresholds), rights, lefts).long()
indexes = torch.where(self.decision_cond(feature_values, thresholds), lefts, rights).long()
indexes = indexes + self.nodes_offset
indexes = indexes.view(-1)
@@ -296,7 +319,7 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
# If n_classes is not provided we induce it from tree parameters. Multioutput regression targets are also treated as separate classes.
n_classes = n_classes if n_classes is not None else tree_parameters[0][6].shape[1]
super(PerfectTreeTraversalTreeImpl, self).__init__(
logical_operator, tree_parameters, n_features, classes, n_classes, **kwargs
logical_operator, tree_parameters, n_features, classes, n_classes, extra_config=extra_config, **kwargs
)
# Initialize the actual model.
@@ -307,7 +330,7 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
node_maps = [tp[0] for tp in tree_parameters]
weight_0 = np.zeros((self.num_trees, 2 ** max_depth - 1))
bias_0 = np.zeros((self.num_trees, 2 ** max_depth - 1))
bias_0 = np.zeros((self.num_trees, 2 ** max_depth - 1), dtype=np.float64)
weight_1 = np.zeros((self.num_trees, 2 ** max_depth, self.n_classes))
for i, node_map in enumerate(node_maps):
@@ -315,9 +338,10 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
node_by_levels = [set() for _ in range(max_depth)]
self._traverse_by_level(node_by_levels, 0, -1, max_depth)
self.root_nodes = torch.nn.Parameter(torch.from_numpy(weight_0[:, 0].flatten().astype("int64")), requires_grad=False)
self.root_biases = torch.nn.Parameter(-1 * torch.from_numpy(bias_0[:, 0].astype("float32")), requires_grad=False)
self.root_biases = torch.nn.Parameter(
torch.from_numpy(bias_0[:, 0].astype(self.tree_op_precision_dtype)), requires_grad=False
)
tree_indices = np.array([i for i in range(0, 2 * self.num_trees, 2)]).astype("int64")
self.tree_indices = torch.nn.Parameter(torch.from_numpy(tree_indices), requires_grad=False)
@@ -329,7 +353,7 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
torch.from_numpy(weight_0[:, list(sorted(node_by_levels[i]))].flatten().astype("int64")), requires_grad=False
)
biases = torch.nn.Parameter(
torch.from_numpy(-1 * bias_0[:, list(sorted(node_by_levels[i]))].flatten().astype("float32")),
torch.from_numpy(bias_0[:, list(sorted(node_by_levels[i]))].flatten().astype(self.tree_op_precision_dtype)),
requires_grad=False,
)
self.nodes.append(nodes)
@@ -339,19 +363,14 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
self.biases = torch.nn.ParameterList(self.biases)
self.leaf_nodes = torch.nn.Parameter(
torch.from_numpy(weight_1.reshape((-1, self.n_classes)).astype("float32")), requires_grad=False
torch.from_numpy(weight_1.reshape((-1, self.n_classes)).astype(self.tree_op_precision_dtype)), requires_grad=False
)
# We register also base_prediction here so that tensor will be moved to the proper hardware with the model.
# i.e., if cuda is selected, the parameter will be automatically moved on the GPU.
if constants.BASE_PREDICTION in extra_config:
self.base_prediction = extra_config[constants.BASE_PREDICTION]
def aggregation(self, x):
return x
def forward(self, x):
prev_indices = (torch.ge(torch.index_select(x, 1, self.root_nodes), self.root_biases)).long()
prev_indices = (self.decision_cond(torch.index_select(x, 1, self.root_nodes), self.root_biases)).long()
prev_indices = prev_indices + self.tree_indices
prev_indices = prev_indices.view(-1)
@@ -359,7 +378,9 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
for nodes, biases in zip(self.nodes, self.biases):
gather_indices = torch.index_select(nodes, 0, prev_indices).view(-1, self.num_trees)
features = torch.gather(x, 1, gather_indices).view(-1)
prev_indices = factor * prev_indices + torch.ge(features, torch.index_select(biases, 0, prev_indices)).long()
prev_indices = (
factor * prev_indices + self.decision_cond(features, torch.index_select(biases, 0, prev_indices)).long()
)
output = torch.index_select(self.leaf_nodes, 0, prev_indices).view(-1, self.num_trees, self.n_classes)
@@ -390,20 +411,11 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
def _get_weights_and_biases(self, nodes_map, tree_depth, weight_0, weight_1, bias_0):
def depth_f_traversal(node, current_depth, node_id, leaf_start_id):
weight_0[node_id] = node.feature
bias_0[node_id] = -node.threshold
bias_0[node_id] = node.threshold
current_depth += 1
node_id += 1
if node.left.feature == -1:
node_id += 2 ** (tree_depth - current_depth - 1) - 1
v = node.left.value
weight_1[leaf_start_id : leaf_start_id + 2 ** (tree_depth - current_depth - 1)] = (
np.ones((2 ** (tree_depth - current_depth - 1), self.n_classes)) * v
)
leaf_start_id += 2 ** (tree_depth - current_depth - 1)
else:
node_id, leaf_start_id = depth_f_traversal(node.left, current_depth, node_id, leaf_start_id)
# Condition false (right sub-tree)
if node.right.feature == -1:
node_id += 2 ** (tree_depth - current_depth - 1) - 1
v = node.right.value
@@ -414,6 +426,17 @@ class PerfectTreeTraversalTreeImpl(AbstractPyTorchTreeImpl):
else:
node_id, leaf_start_id = depth_f_traversal(node.right, current_depth, node_id, leaf_start_id)
# Condition true (left sub-tree)
if node.left.feature == -1:
node_id += 2 ** (tree_depth - current_depth - 1) - 1
v = node.left.value
weight_1[leaf_start_id : leaf_start_id + 2 ** (tree_depth - current_depth - 1)] = (
np.ones((2 ** (tree_depth - current_depth - 1), self.n_classes)) * v
)
leaf_start_id += 2 ** (tree_depth - current_depth - 1)
else:
node_id, leaf_start_id = depth_f_traversal(node.left, current_depth, node_id, leaf_start_id)
return node_id, leaf_start_id
depth_f_traversal(nodes_map[0], -1, 0, 0)
@@ -426,14 +449,17 @@ class GEMMDecisionTreeImpl(GEMMTreeImpl):
"""
def __init__(self, logical_operator, tree_parameters, n_features, classes=None):
def __init__(self, logical_operator, tree_parameters, n_features, classes=None, extra_config={}):
"""
Args:
tree_parameters: The parameters defining the tree structure
n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model
extra_config: Extra configuration used to properly implement the source tree
"""
super(GEMMDecisionTreeImpl, self).__init__(logical_operator, tree_parameters, n_features, classes)
super(GEMMDecisionTreeImpl, self).__init__(
logical_operator, tree_parameters, n_features, classes, extra_config=extra_config
)
def aggregation(self, x):
output = x.sum(0).t()
@@ -446,7 +472,7 @@ class TreeTraversalDecisionTreeImpl(TreeTraversalTreeImpl):
Class implementing the Tree Traversal strategy in PyTorch for decision tree models.
"""
def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}):
def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}, **kwargs):
"""
Args:
tree_parameters: The parameters defining the tree structure
@@ -456,7 +482,7 @@ class TreeTraversalDecisionTreeImpl(TreeTraversalTreeImpl):
extra_config: Extra configuration used to properly implement the source tree
"""
super(TreeTraversalDecisionTreeImpl, self).__init__(
logical_operator, tree_parameters, max_depth, n_features, classes, extra_config=extra_config
logical_operator, tree_parameters, max_depth, n_features, classes, extra_config=extra_config, **kwargs
)
def aggregation(self, x):
@@ -470,16 +496,17 @@ class PerfectTreeTraversalDecisionTreeImpl(PerfectTreeTraversalTreeImpl):
Class implementing the Perfect Tree Traversal strategy in PyTorch for decision tree models.
"""
def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None):
def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}, **kwargs):
"""
Args:
tree_parameters: The parameters defining the tree structure
max_depth: The maximum tree-depth in the model
n_features: The number of features input to the model
classes: The classes used for classification. None if implementing a regression model
extra_config: Extra configuration used to properly implement the source tree
"""
super(PerfectTreeTraversalDecisionTreeImpl, self).__init__(
logical_operator, tree_parameters, max_depth, n_features, classes
logical_operator, tree_parameters, max_depth, n_features, classes, extra_config=extra_config, **kwargs
)
def aggregation(self, x):
@@ -494,7 +521,7 @@ class GEMMGBDTImpl(GEMMTreeImpl):
Class implementing the GEMM strategy (in PyTorch) for GBDT models.
"""
def __init__(self, logical_operator, tree_parameters, n_features, classes=None, extra_config={}):
def __init__(self, logical_operator, tree_parameters, n_features, classes=None, extra_config={}, **kwargs):
"""
Args:
tree_parameters: The parameters defining the tree structure
@@ -502,7 +529,7 @@ class GEMMGBDTImpl(GEMMTreeImpl):
classes: The classes used for classification. None if implementing a regression model
extra_config: Extra configuration used to properly implement the source tree
"""
super(GEMMGBDTImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, 1, extra_config)
super(GEMMGBDTImpl, self).__init__(logical_operator, tree_parameters, n_features, classes, 1, extra_config, **kwargs)
self.n_gbdt_classes = 1
self.post_transform = _tree_commons.PostTransform()
@@ -526,7 +553,7 @@ class TreeTraversalGBDTImpl(TreeTraversalTreeImpl):
Class implementing the Tree Traversal strategy in PyTorch.
"""
def __init__(self, logical_operator, tree_parameters, max_detph, n_features, classes=None, extra_config={}):
def __init__(self, logical_operator, tree_parameters, max_detph, n_features, classes=None, extra_config={}, **kwargs):
"""
Args:
tree_parameters: The parameters defining the tree structure
@@ -536,7 +563,7 @@ class TreeTraversalGBDTImpl(TreeTraversalTreeImpl):
extra_config: Extra configuration used to properly implement the source tree
"""
super(TreeTraversalGBDTImpl, self).__init__(
logical_operator, tree_parameters, max_detph, n_features, classes, 1, extra_config
logical_operator, tree_parameters, max_detph, n_features, classes, 1, extra_config, **kwargs
)
self.n_gbdt_classes = 1
@@ -561,7 +588,7 @@ class PerfectTreeTraversalGBDTImpl(PerfectTreeTraversalTreeImpl):
Class implementing the Perfect Tree Traversal strategy in PyTorch.
"""
def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}):
def __init__(self, logical_operator, tree_parameters, max_depth, n_features, classes=None, extra_config={}, **kwargs):
"""
Args:
tree_parameters: The parameters defining the tree structure
@@ -571,7 +598,7 @@ class PerfectTreeTraversalGBDTImpl(PerfectTreeTraversalTreeImpl):
extra_config: Extra configuration used to properly implement the source tree
"""
super(PerfectTreeTraversalGBDTImpl, self).__init__(
logical_operator, tree_parameters, max_depth, n_features, classes, 1, extra_config
logical_operator, tree_parameters, max_depth, n_features, classes, 1, extra_config, **kwargs
)
self.n_gbdt_classes = 1
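
The traversal rewrite in this file replaces the hard-coded torch.where(torch.ge(feature_values, thresholds), rights, lefts) with the configurable torch.where(self.decision_cond(feature_values, thresholds), lefts, rights). A sketch of one traversal step under both forms, with toy node indices; note how the tie at the threshold moves:

import torch

feature_values = torch.tensor([0.5, 1.0, 1.5])
thresholds = torch.tensor([1.0, 1.0, 1.0])
lefts, rights = torch.tensor([1, 1, 1]), torch.tensor([2, 2, 2])

old = torch.where(torch.ge(feature_values, thresholds), rights, lefts)
new = torch.where(torch.le(feature_values, thresholds), lefts, rights)

print(old)  # tensor([1, 2, 2]): feature == threshold goes right
print(new)  # tensor([1, 1, 2]): with '<=' the tie goes left instead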

View file

@@ -99,7 +99,7 @@ def convert_sklearn_xgb_classifier(operator, device, extra_config):
n_classes = operator.raw_operator.n_classes_
return convert_gbdt_classifier_common(
operator, tree_infos, _get_tree_parameters, n_features, n_classes, extra_config=extra_config
operator, tree_infos, _get_tree_parameters, n_features, n_classes, decision_cond="<", extra_config=extra_config
)
@@ -133,7 +133,9 @@ def convert_sklearn_xgb_regressor(operator, device, extra_config):
extra_config[constants.BASE_PREDICTION] = base_prediction
return convert_gbdt_common(operator, tree_infos, _get_tree_parameters, n_features, extra_config=extra_config)
return convert_gbdt_common(
operator, tree_infos, _get_tree_parameters, n_features, decision_cond="<", extra_config=extra_config
)
# Register the converters.
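
XGBoost splits on a strict less-than, which is why the converters above pass decision_cond="<" rather than the "<=" default. The implementations map these strings to torch comparison ops via the decision_cond_map added earlier in this commit; a brief sketch of the difference at a tie:

import torch

decision_cond_map = {"<=": torch.le, "<": torch.lt, ">=": torch.ge, ">": torch.gt}
x, t = torch.tensor([1.0]), torch.tensor([1.0])

print(decision_cond_map["<"](x, t))   # tensor([False]): tie goes right (XGBoost)
print(decision_cond_map["<="](x, t))  # tensor([True]):  tie goes left (default)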

View file

@@ -469,6 +469,9 @@ backends = _build_backend_map()
TREE_IMPLEMENTATION = "tree_implementation"
"""Which tree implementation to use. Values can be: gemm, tree_trav, perf_tree_trav."""
TREE_OP_PRECISION_DTYPE = "tree_op_precision_dtype"
"""Which data type to be used for the threshold and leaf values of decision nodes. Values can be: float32 or float64."""
ONNX_OUTPUT_MODEL_NAME = "onnx_model_name"
"""For ONNX models we can set the name of the output model."""

View file

@@ -7,6 +7,7 @@ import warnings
import numpy as np
import hummingbird.ml
from hummingbird.ml import constants
from hummingbird.ml._utils import lightgbm_installed, onnx_runtime_installed, tvm_installed
from tree_utils import gbdt_implementation_map
@@ -362,11 +363,15 @@ class TestLGBMConverter(unittest.TestCase):
model.fit(X, y)
# Create TVM model.
tvm_model = hummingbird.ml.convert(model, "tvm", X, extra_config={"tree_implementation": tree_implementation})
tvm_model = hummingbird.ml.convert(
model,
"tvm",
X,
extra_config={constants.TREE_IMPLEMENTATION: tree_implementation, constants.TREE_OP_PRECISION_DTYPE: "float64"}
)
# Check results.
np.testing.assert_allclose(tvm_model.predict(X), model.predict(X))
np.testing.assert_allclose(tvm_model.predict_proba(X), model.predict_proba(X), rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(tvm_model.predict_proba(X), model.predict_proba(X), rtol=1e-04, atol=1e-04)
if __name__ == "__main__":
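
On the tolerance change above (rtol 1e-06 to 1e-04): numpy accepts a comparison when |actual - desired| <= atol + rtol * |desired|, so the loosened settings absorb float32 accumulation noise without hiding gross regressions. For example:

import numpy as np

# Passes: |1.0001 - 1.0| = 1e-4 <= atol + rtol * |1.0| = 2e-4.
np.testing.assert_allclose(1.0001, 1.0, rtol=1e-04, atol=1e-04)

# Would fail: |1.001 - 1.0| = 1e-3 > 2e-4.
# np.testing.assert_allclose(1.001, 1.0, rtol=1e-04, atol=1e-04)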

View file

@@ -15,6 +15,8 @@ from hummingbird.ml._utils import tvm_installed
from hummingbird.ml import constants
from tree_utils import dt_implementation_map
import random
class TestSklearnTreeConverter(unittest.TestCase):
# Check tree implementation
@@ -725,17 +727,25 @@ class TestSklearnTreeConverter(unittest.TestCase):
for tree_method in ["gemm", "tree_trav", "perf_tree_trav"]:
for n_targets in [1, 2, 7]:
for tree_class in [DecisionTreeRegressor, ExtraTreesRegressor, RandomForestRegressor]:
model = tree_class()
seed = random.randint(0, 2**32 - 1)
if tree_method == "perf_tree_trav":
model = tree_class(random_state=seed, max_depth=10)
else:
model = tree_class(random_state=seed)
X, y = datasets.make_regression(
n_samples=100, n_features=10, n_informative=5, n_targets=n_targets, random_state=2021
n_samples=100, n_features=10, n_informative=5, n_targets=n_targets, random_state=seed
)
model.fit(X, y)
X = X.astype('float32')
y = y.astype('float32')
torch_model = hummingbird.ml.convert(
model, "torch", extra_config={constants.TREE_IMPLEMENTATION: tree_method}
model,
"torch",
extra_config={constants.TREE_IMPLEMENTATION: tree_method, constants.TREE_OP_PRECISION_DTYPE: "float64"}
)
self.assertTrue(torch_model is not None)
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-5, atol=1e-5, err_msg="{}/{}/{}/{}".format(tree_method, n_targets, tree_class, seed))
if __name__ == "__main__":
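
The max_depth=10 cap for perf_tree_trav above follows from that strategy's memory profile: as in the 2 ** max_depth - 1 allocations earlier in this commit, every tree is padded to a perfect binary tree, so node storage grows exponentially with depth:

# Internal nodes materialized per tree by the perfect-tree-traversal layout.
for depth in (5, 10, 20, 30):
    print(depth, 2 ** depth - 1)
# 5 -> 31, 10 -> 1023, 20 -> 1048575, 30 -> 1073741823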

View file

@@ -17,46 +17,53 @@ from sklearn.svm import LinearSVR
from sklearn.multioutput import MultiOutputRegressor, RegressorChain
import hummingbird.ml
from hummingbird.ml import constants
import random
random.seed(2021)
class TestSklearnMultioutputRegressor(unittest.TestCase):
# Test MultiOutputRegressor with different child learners
def test_sklearn_multioutput_regressor(self):
for n_targets in [2, 3, 4]:
for model_class in [DecisionTreeRegressor, ExtraTreesRegressor, RandomForestRegressor, LinearRegression]:
model = MultiOutputRegressor(model_class())
seed = random.randint(0, 2**32 - 1)
if model_class != LinearRegression:
model = MultiOutputRegressor(model_class(random_state=seed))
else:
model = MultiOutputRegressor(model_class())
X, y = datasets.make_regression(
n_samples=50, n_features=10, n_informative=5, n_targets=n_targets, random_state=2020
n_samples=50, n_features=10, n_informative=5, n_targets=n_targets, random_state=seed
)
X = X.astype("float32")
y = y.astype("float32")
model.fit(X, y)
torch_model = hummingbird.ml.convert(model, "torch")
torch_model = hummingbird.ml.convert(model, "torch", extra_config={constants.TREE_OP_PRECISION_DTYPE: "float64"})
self.assertTrue(torch_model is not None)
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-5, atol=1e-4)
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-5, atol=1e-4, err_msg="{}/{}/{}".format(n_targets, model_class, seed))
# Test RegressorChain with different child learners
def test_sklearn_regressor_chain(self):
for n_targets in [2, 3, 4]:
for model_class in [DecisionTreeRegressor, ExtraTreesRegressor, RandomForestRegressor, LinearRegression]:
seed = random.randint(0, 2**32 - 1)
order = [i for i in range(n_targets)]
random.shuffle(order)
model = RegressorChain(model_class(), order=order)
random.Random(seed).shuffle(order)
if model_class != LinearRegression:
model = RegressorChain(model_class(random_state=seed), order=order)
else:
model = RegressorChain(model_class(), order=order)
X, y = datasets.make_regression(
n_samples=50, n_features=10, n_informative=5, n_targets=n_targets, random_state=2021
n_samples=50, n_features=10, n_informative=5, n_targets=n_targets, random_state=seed
)
X = X.astype("float32")
y = y.astype("float32")
model.fit(X, y)
torch_model = hummingbird.ml.convert(model, "torch")
torch_model = hummingbird.ml.convert(model, "torch", extra_config={constants.TREE_OP_PRECISION_DTYPE: "float64"})
self.assertTrue(torch_model is not None)
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(model.predict(X), torch_model.predict(X), rtol=1e-4, atol=1e-4, err_msg="{}/{}/{}".format(n_targets, model_class, seed))
if __name__ == "__main__":
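
A note on the seeding pattern used throughout these tests: drawing a fresh seed per case, feeding it to both the data generator and the model, and echoing it in err_msg makes any failure replayable. A minimal sketch:

import random

seed = random.randint(0, 2**32 - 1)
order = list(range(4))
random.Random(seed).shuffle(order)  # deterministic for a given seed
print(seed, order)                  # log the seed so a failure can be rerun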