Implementation of TreeEnsemble ai.onnx.ml==5 (#22333)
### Description Merges PR #21851, #21222. Implements TreeEnsemble from ai.onnx.ml==5 (CPU). --------- Co-authored-by: Bilyana Indzheva <bilyana2002@gmail.com> Co-authored-by: Bilyana Indzheva <36890669+bili2002@users.noreply.github.com> Co-authored-by: Christian Bourjau <cbourjau@users.noreply.github.com>
This commit is contained in:
Родитель
c97dd6e3c1
Коммит
a2ba3cb547
|
@ -453,6 +453,7 @@ Do not modify directly.*
|
|||
|SVMClassifier|*in* X:**T1**<br> *out* Y:**T2**<br> *out* Z:**tensor(float)**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int64), tensor(string)|
|
||||
|SVMRegressor|*in* X:**T**<br> *out* Y:**tensor(float)**|1+|**T** = tensor(float)|
|
||||
|Scaler|*in* X:**T**<br> *out* Y:**tensor(float)**|1+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|
||||
|TreeEnsemble|*in* X:**T**<br> *out* Y:**T**|5+|**T** = tensor(double), tensor(float)|
|
||||
|TreeEnsembleClassifier|*in* X:**T1**<br> *out* Y:**T2**<br> *out* Z:**tensor(float)**|3+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int64), tensor(string)|
|
||||
|||[1, 2]|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int64), tensor(string)|
|
||||
|TreeEnsembleRegressor|*in* X:**T**<br> *out* Y:**tensor(float)**|3+|**T** = tensor(double), tensor(float)|
|
||||
|
|
|
@ -2925,6 +2925,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, int32_t, TreeEnsembleClassifier);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, float, TreeEnsembleRegressor);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double, TreeEnsembleRegressor);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, float, TreeEnsemble);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, double, TreeEnsemble);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, float_string, LabelEncoder);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, string_float, LabelEncoder);
|
||||
|
@ -3043,6 +3045,10 @@ Status RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
TreeEnsembleRegressor)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double,
|
||||
TreeEnsembleRegressor)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, float,
|
||||
TreeEnsemble)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, double,
|
||||
TreeEnsemble)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, float_string,
|
||||
LabelEncoder)>,
|
||||
|
|
|
@ -20,44 +20,48 @@ enum class OUTPUT_MODE {
|
|||
ALL_SCORES
|
||||
};
|
||||
|
||||
enum NODE_MODE : uint8_t {
|
||||
LEAF = 1,
|
||||
BRANCH_LEQ = 2,
|
||||
BRANCH_LT = 4,
|
||||
BRANCH_GTE = 6,
|
||||
BRANCH_GT = 8,
|
||||
BRANCH_EQ = 10,
|
||||
BRANCH_NEQ = 12
|
||||
enum NODE_MODE_ONNX : uint8_t {
|
||||
BRANCH_LEQ = 0,
|
||||
BRANCH_LT = 1,
|
||||
BRANCH_GTE = 2,
|
||||
BRANCH_GT = 3,
|
||||
BRANCH_EQ = 4,
|
||||
BRANCH_NEQ = 5,
|
||||
BRANCH_MEMBER = 6,
|
||||
LEAF = 7,
|
||||
};
|
||||
|
||||
static inline NODE_MODE MakeTreeNodeMode(const std::string& input) {
|
||||
static inline NODE_MODE_ONNX MakeTreeNodeMode(const std::string& input) {
|
||||
if (input == "BRANCH_LEQ") {
|
||||
return NODE_MODE::BRANCH_LEQ;
|
||||
return NODE_MODE_ONNX::BRANCH_LEQ;
|
||||
}
|
||||
if (input == "LEAF") {
|
||||
return NODE_MODE::LEAF;
|
||||
return NODE_MODE_ONNX::LEAF;
|
||||
}
|
||||
if (input == "BRANCH_LT") {
|
||||
return NODE_MODE::BRANCH_LT;
|
||||
return NODE_MODE_ONNX::BRANCH_LT;
|
||||
}
|
||||
if (input == "BRANCH_GTE") {
|
||||
return NODE_MODE::BRANCH_GTE;
|
||||
return NODE_MODE_ONNX::BRANCH_GTE;
|
||||
}
|
||||
if (input == "BRANCH_GT") {
|
||||
return NODE_MODE::BRANCH_GT;
|
||||
return NODE_MODE_ONNX::BRANCH_GT;
|
||||
}
|
||||
if (input == "BRANCH_EQ") {
|
||||
return NODE_MODE::BRANCH_EQ;
|
||||
return NODE_MODE_ONNX::BRANCH_EQ;
|
||||
}
|
||||
return NODE_MODE::BRANCH_NEQ;
|
||||
if (input == "BRANCH_MEMBER") {
|
||||
return NODE_MODE_ONNX::BRANCH_MEMBER;
|
||||
}
|
||||
return NODE_MODE_ONNX::BRANCH_NEQ;
|
||||
}
|
||||
|
||||
enum class POST_EVAL_TRANSFORM {
|
||||
NONE,
|
||||
LOGISTIC,
|
||||
SOFTMAX,
|
||||
SOFTMAX_ZERO,
|
||||
PROBIT
|
||||
enum class POST_EVAL_TRANSFORM : int64_t {
|
||||
NONE = 0,
|
||||
LOGISTIC = 1,
|
||||
SOFTMAX = 2,
|
||||
SOFTMAX_ZERO = 3,
|
||||
PROBIT = 4
|
||||
};
|
||||
|
||||
static inline POST_EVAL_TRANSFORM MakeTransform(const std::string& input) {
|
||||
|
@ -76,11 +80,11 @@ static inline POST_EVAL_TRANSFORM MakeTransform(const std::string& input) {
|
|||
return POST_EVAL_TRANSFORM::PROBIT;
|
||||
}
|
||||
|
||||
enum class AGGREGATE_FUNCTION {
|
||||
AVERAGE,
|
||||
SUM,
|
||||
MIN,
|
||||
MAX
|
||||
enum class AGGREGATE_FUNCTION : int64_t {
|
||||
AVERAGE = 0,
|
||||
SUM = 1,
|
||||
MIN = 2,
|
||||
MAX = 3
|
||||
};
|
||||
|
||||
static inline AGGREGATE_FUNCTION MakeAggregateFunction(const std::string& input) {
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cpu/ml/tree_ensemble.h"
|
||||
#include "core/providers/cpu/ml/tree_ensemble_helper.h"
|
||||
#include "core/common/inlined_containers_fwd.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace ml {
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
|
||||
TreeEnsemble,
|
||||
5,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).MayInplace(0, 0),
|
||||
TreeEnsemble<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
|
||||
TreeEnsemble,
|
||||
5,
|
||||
double,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()).MayInplace(0, 0),
|
||||
TreeEnsemble<double>);
|
||||
|
||||
template <typename T>
|
||||
TreeEnsemble<T>::TreeEnsemble(const OpKernelInfo& info) : OpKernel(info) {
|
||||
if constexpr (std::is_same<T, double>::value) {
|
||||
p_tree_ensemble_ = std::make_unique<detail::TreeEnsembleCommonV5<T, double>>();
|
||||
} else {
|
||||
p_tree_ensemble_ = std::make_unique<detail::TreeEnsembleCommonV5<T, float>>();
|
||||
}
|
||||
ORT_THROW_IF_ERROR(p_tree_ensemble_->Init(info));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status TreeEnsemble<T>::GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const {
|
||||
InlinedVector<std::string> names{
|
||||
"leaf_targetids", "leaf_weights", "membership_values", "nodes_falseleafs",
|
||||
"nodes_falsenodeids", "nodes_featureids", "nodes_hitrates", "nodes_missing_value_tracks_true",
|
||||
"nodes_modes", "nodes_splits", "nodes_trueleafs", "nodes_truenodeids"};
|
||||
removable_attributes.swap(names);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
common::Status TreeEnsemble<T>::Compute(OpKernelContext* context) const {
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
if (X->Shape().NumDimensions() == 0) {
|
||||
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
|
||||
"Input shape needs to be at least a single dimension.");
|
||||
}
|
||||
int64_t N = X->Shape().NumDimensions() == 1 ? 1 : X->Shape()[0];
|
||||
Tensor* Y = context->Output(0, {N, p_tree_ensemble_->get_target_or_class_count()});
|
||||
return p_tree_ensemble_->compute(context, X, Y, NULL);
|
||||
}
|
||||
|
||||
} // namespace ml
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,25 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "tree_ensemble_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace ml {
|
||||
template <typename T>
|
||||
class TreeEnsemble final : public OpKernel {
|
||||
typedef T InputType; // input type
|
||||
typedef float OutputType; // output type
|
||||
public:
|
||||
explicit TreeEnsemble(const OpKernelInfo& info);
|
||||
common::Status Compute(OpKernelContext* context) const override;
|
||||
Status GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const override;
|
||||
|
||||
private:
|
||||
// Pointer on one instance of
|
||||
// detail::TreeEnsembleCommonV5<T, ThresholdType>
|
||||
// where ThresholdType is defined after accessing the attributes.
|
||||
std::unique_ptr<detail::TreeEnsembleCommonAttributes> p_tree_ensemble_;
|
||||
};
|
||||
} // namespace ml
|
||||
} // namespace onnxruntime
|
|
@ -78,6 +78,40 @@ union PtrOrWeight {
|
|||
} weight_data;
|
||||
};
|
||||
|
||||
enum NODE_MODE_ORT : uint8_t {
|
||||
LEAF = 1,
|
||||
BRANCH_LEQ = 2,
|
||||
BRANCH_LT = 4,
|
||||
BRANCH_GTE = 6,
|
||||
BRANCH_GT = 8,
|
||||
BRANCH_EQ = 10,
|
||||
BRANCH_NEQ = 12,
|
||||
BRANCH_MEMBER = 14,
|
||||
};
|
||||
|
||||
inline NODE_MODE_ORT Convert_NODE_MODE_ONNX_to_ORT(NODE_MODE_ONNX node_mode) {
|
||||
switch (node_mode) {
|
||||
case NODE_MODE_ONNX::LEAF:
|
||||
return NODE_MODE_ORT::LEAF;
|
||||
case NODE_MODE_ONNX::BRANCH_LEQ:
|
||||
return NODE_MODE_ORT::BRANCH_LEQ;
|
||||
case NODE_MODE_ONNX::BRANCH_LT:
|
||||
return NODE_MODE_ORT::BRANCH_LT;
|
||||
case NODE_MODE_ONNX::BRANCH_GTE:
|
||||
return NODE_MODE_ORT::BRANCH_GTE;
|
||||
case NODE_MODE_ONNX::BRANCH_GT:
|
||||
return NODE_MODE_ORT::BRANCH_GT;
|
||||
case NODE_MODE_ONNX::BRANCH_EQ:
|
||||
return NODE_MODE_ORT::BRANCH_EQ;
|
||||
case NODE_MODE_ONNX::BRANCH_NEQ:
|
||||
return NODE_MODE_ORT::BRANCH_NEQ;
|
||||
case NODE_MODE_ONNX::BRANCH_MEMBER:
|
||||
return NODE_MODE_ORT::BRANCH_MEMBER;
|
||||
default:
|
||||
ORT_THROW("Unexpected value for node_mode");
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct TreeNodeElement {
|
||||
int feature_id;
|
||||
|
@ -98,10 +132,10 @@ struct TreeNodeElement {
|
|||
// weight in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, the weight is also
|
||||
// stored in `value_or_unique_weight`.
|
||||
PtrOrWeight<T> truenode_or_weight;
|
||||
uint8_t flags;
|
||||
NODE_MODE_ORT flags;
|
||||
|
||||
inline NODE_MODE mode() const { return NODE_MODE(flags & 0xF); }
|
||||
inline bool is_not_leaf() const { return !(flags & NODE_MODE::LEAF); }
|
||||
inline NODE_MODE_ORT mode() const { return NODE_MODE_ORT(flags & 0xF); }
|
||||
inline bool is_not_leaf() const { return !(flags & NODE_MODE_ORT::LEAF); }
|
||||
inline bool is_missing_track_true() const { return flags & MissingTrack::kTrue; }
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,321 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/inlined_containers.h"
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "ml_common.h"
|
||||
#include "tree_ensemble_helper.h"
|
||||
#include <vector>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace ml {
|
||||
namespace detail {
|
||||
|
||||
inline bool _isnan_(float x) { return std::isnan(x); }
|
||||
inline bool _isnan_(double x) { return std::isnan(x); }
|
||||
inline bool _isnan_(int64_t) { return false; }
|
||||
inline bool _isnan_(int32_t) { return false; }
|
||||
|
||||
template <typename ThresholdType>
|
||||
struct TreeEnsembleAttributesV3 {
|
||||
TreeEnsembleAttributesV3() {}
|
||||
TreeEnsembleAttributesV3(const OpKernelInfo& info, bool classifier) {
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor));
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor));
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor));
|
||||
if (classifier) {
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "class_weights_as_tensor", target_class_weights_as_tensor));
|
||||
} else {
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "target_weights_as_tensor", target_class_weights_as_tensor));
|
||||
}
|
||||
#endif
|
||||
|
||||
aggregate_function = info.GetAttrOrDefault<std::string>("aggregate_function", "SUM");
|
||||
base_values = info.GetAttrsOrDefault<float>("base_values");
|
||||
nodes_falsenodeids = info.GetAttrsOrDefault<int64_t>("nodes_falsenodeids");
|
||||
nodes_featureids = info.GetAttrsOrDefault<int64_t>("nodes_featureids");
|
||||
nodes_missing_value_tracks_true = info.GetAttrsOrDefault<int64_t>("nodes_missing_value_tracks_true");
|
||||
|
||||
std::vector<std::string> nodes_modes_string = info.GetAttrsOrDefault<std::string>("nodes_modes");
|
||||
nodes_modes.reserve(nodes_modes_string.size());
|
||||
for (auto s : nodes_modes_string) {
|
||||
nodes_modes.emplace_back(MakeTreeNodeMode(s));
|
||||
}
|
||||
|
||||
nodes_nodeids = info.GetAttrsOrDefault<int64_t>("nodes_nodeids");
|
||||
nodes_treeids = info.GetAttrsOrDefault<int64_t>("nodes_treeids");
|
||||
nodes_truenodeids = info.GetAttrsOrDefault<int64_t>("nodes_truenodeids");
|
||||
nodes_values = info.GetAttrsOrDefault<float>("nodes_values");
|
||||
post_transform = info.GetAttrOrDefault<std::string>("post_transform", "NONE");
|
||||
|
||||
if (classifier) {
|
||||
target_class_ids = info.GetAttrsOrDefault<int64_t>("class_ids");
|
||||
target_class_nodeids = info.GetAttrsOrDefault<int64_t>("class_nodeids");
|
||||
target_class_treeids = info.GetAttrsOrDefault<int64_t>("class_treeids");
|
||||
target_class_weights = info.GetAttrsOrDefault<float>("class_weights");
|
||||
classlabels_strings = info.GetAttrsOrDefault<std::string>("classlabels_strings");
|
||||
classlabels_int64s = info.GetAttrsOrDefault<int64_t>("classlabels_int64s");
|
||||
n_targets_or_classes = classlabels_strings.empty() ? classlabels_int64s.size()
|
||||
: classlabels_strings.size();
|
||||
} else {
|
||||
n_targets_or_classes = info.GetAttrOrDefault<int64_t>("n_targets", 0);
|
||||
target_class_ids = info.GetAttrsOrDefault<int64_t>("target_ids");
|
||||
target_class_nodeids = info.GetAttrsOrDefault<int64_t>("target_nodeids");
|
||||
target_class_treeids = info.GetAttrsOrDefault<int64_t>("target_treeids");
|
||||
target_class_weights = info.GetAttrsOrDefault<float>("target_weights");
|
||||
|
||||
ORT_ENFORCE(n_targets_or_classes > 0);
|
||||
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_featureids.size());
|
||||
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_modes_string.size());
|
||||
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_nodeids.size());
|
||||
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_treeids.size());
|
||||
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_truenodeids.size());
|
||||
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_values.size() ||
|
||||
nodes_falsenodeids.size() == nodes_values_as_tensor.size());
|
||||
ORT_ENFORCE(target_class_ids.size() == target_class_nodeids.size());
|
||||
ORT_ENFORCE(target_class_ids.size() == target_class_treeids.size());
|
||||
ORT_ENFORCE(target_class_weights.empty() || target_class_ids.size() == target_class_weights.size());
|
||||
ORT_ENFORCE(base_values.empty() || base_values_as_tensor.empty());
|
||||
ORT_ENFORCE(nodes_hitrates.empty() || nodes_hitrates_as_tensor.empty());
|
||||
ORT_ENFORCE(nodes_values.empty() || nodes_values_as_tensor.empty());
|
||||
ORT_ENFORCE(target_class_weights.empty() || target_class_weights_as_tensor.empty());
|
||||
ORT_ENFORCE(nodes_modes_string.size() < std::numeric_limits<uint32_t>::max());
|
||||
}
|
||||
}
|
||||
|
||||
std::string aggregate_function;
|
||||
std::vector<float> base_values;
|
||||
std::vector<ThresholdType> base_values_as_tensor;
|
||||
int64_t n_targets_or_classes;
|
||||
std::vector<int64_t> nodes_falsenodeids;
|
||||
std::vector<int64_t> nodes_featureids;
|
||||
std::vector<float> nodes_hitrates;
|
||||
std::vector<ThresholdType> nodes_hitrates_as_tensor;
|
||||
std::vector<int64_t> nodes_missing_value_tracks_true;
|
||||
std::vector<NODE_MODE_ONNX> nodes_modes;
|
||||
std::vector<int64_t> nodes_nodeids;
|
||||
std::vector<int64_t> nodes_treeids;
|
||||
std::vector<int64_t> nodes_truenodeids;
|
||||
std::vector<float> nodes_values;
|
||||
std::vector<ThresholdType> nodes_values_as_tensor;
|
||||
std::string post_transform;
|
||||
std::vector<int64_t> target_class_ids;
|
||||
std::vector<int64_t> target_class_nodeids;
|
||||
std::vector<int64_t> target_class_treeids;
|
||||
std::vector<float> target_class_weights;
|
||||
std::vector<ThresholdType> target_class_weights_as_tensor;
|
||||
std::vector<std::string> classlabels_strings;
|
||||
std::vector<int64_t> classlabels_int64s;
|
||||
std::vector<int64_t> class_labels;
|
||||
};
|
||||
|
||||
template <typename ThresholdType>
|
||||
struct TreeEnsembleAttributesV5 {
|
||||
TreeEnsembleAttributesV5() {}
|
||||
TreeEnsembleAttributesV5(const OpKernelInfo& info) {
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
std::vector<uint8_t> nodes_modes_i;
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "leaf_weights", leaf_weights));
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "membership_values", membership_values));
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates", nodes_hitrates));
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_modes", nodes_modes_i));
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_splits", nodes_splits));
|
||||
nodes_modes.reserve(nodes_modes.size());
|
||||
for (auto i : nodes_modes_i) {
|
||||
nodes_modes.push_back(static_cast<NODE_MODE_ONNX>(i));
|
||||
}
|
||||
#else
|
||||
// GetVectorAttrsOrDefault is not part of the minimal build.
|
||||
// As a result, TreeEnsemble v5 cannot be available in this build.
|
||||
ORT_THROW("TreeEnsemble(ai.onnx.ml==5) is not supported with the minimal build.");
|
||||
#endif
|
||||
|
||||
aggregate_function = info.GetAttrOrDefault<int64_t>("aggregate_function", 1);
|
||||
leaf_targetids = info.GetAttrsOrDefault<int64_t>("leaf_targetids");
|
||||
n_targets = info.GetAttrOrDefault<int64_t>("n_targets", 0);
|
||||
nodes_falseleafs = info.GetAttrsOrDefault<int64_t>("nodes_falseleafs");
|
||||
nodes_falsenodeids = info.GetAttrsOrDefault<int64_t>("nodes_falsenodeids");
|
||||
nodes_featureids = info.GetAttrsOrDefault<int64_t>("nodes_featureids");
|
||||
nodes_missing_value_tracks_true = info.GetAttrsOrDefault<int64_t>("nodes_missing_value_tracks_true");
|
||||
nodes_trueleafs = info.GetAttrsOrDefault<int64_t>("nodes_trueleafs");
|
||||
nodes_truenodeids = info.GetAttrsOrDefault<int64_t>("nodes_truenodeids");
|
||||
post_transform = info.GetAttrOrDefault<int64_t>("post_transform", 0);
|
||||
tree_roots = info.GetAttrsOrDefault<int64_t>("tree_roots");
|
||||
}
|
||||
|
||||
void convert_to_v3(TreeEnsembleAttributesV3<ThresholdType>& output) const {
|
||||
// Doing all transformations to get the old format.
|
||||
output.n_targets_or_classes = n_targets;
|
||||
output.aggregate_function = aggregateFunctionToString();
|
||||
output.post_transform = postTransformToString();
|
||||
std::vector<std::vector<ThresholdType>> membership_values_by_id;
|
||||
getMembershipValuesById(membership_values_by_id);
|
||||
transformInputAllTrees(output, membership_values_by_id);
|
||||
}
|
||||
|
||||
int64_t aggregate_function;
|
||||
std::vector<int64_t> leaf_targetids;
|
||||
std::vector<ThresholdType> leaf_weights;
|
||||
std::vector<ThresholdType> membership_values;
|
||||
int64_t n_targets;
|
||||
std::vector<int64_t> nodes_falseleafs;
|
||||
std::vector<int64_t> nodes_falsenodeids;
|
||||
std::vector<int64_t> nodes_featureids;
|
||||
std::vector<ThresholdType> nodes_hitrates;
|
||||
std::vector<int64_t> nodes_missing_value_tracks_true;
|
||||
std::vector<NODE_MODE_ONNX> nodes_modes;
|
||||
std::vector<ThresholdType> nodes_splits;
|
||||
std::vector<int64_t> nodes_trueleafs;
|
||||
std::vector<int64_t> nodes_truenodeids;
|
||||
int64_t post_transform;
|
||||
std::vector<int64_t> tree_roots;
|
||||
|
||||
private:
|
||||
// `membership_values` are seperated by NAN for different nodes
|
||||
// It is more convenient to preserve the values for each node in a vector
|
||||
// The vector would be empty for nodes that are not `BRANCH_MEMBER`
|
||||
void getMembershipValuesById(std::vector<std::vector<ThresholdType>>& membership_values_by_id) const {
|
||||
membership_values_by_id.clear();
|
||||
membership_values_by_id.reserve(nodes_modes.size());
|
||||
|
||||
size_t curr_id = 0;
|
||||
for (const auto node_mode : nodes_modes) {
|
||||
membership_values_by_id.emplace_back();
|
||||
if (node_mode != NODE_MODE_ONNX::BRANCH_MEMBER) {
|
||||
continue;
|
||||
}
|
||||
|
||||
while (curr_id < membership_values.size() && !_isnan_(membership_values[curr_id])) {
|
||||
membership_values_by_id.back().push_back(membership_values[curr_id++]);
|
||||
}
|
||||
curr_id++;
|
||||
}
|
||||
}
|
||||
|
||||
std::string aggregateFunctionToString() const {
|
||||
switch (aggregate_function) {
|
||||
case static_cast<int64_t>(AGGREGATE_FUNCTION::AVERAGE):
|
||||
return "AVERAGE";
|
||||
case static_cast<int64_t>(AGGREGATE_FUNCTION::SUM):
|
||||
return "SUM";
|
||||
case static_cast<int64_t>(AGGREGATE_FUNCTION::MIN):
|
||||
return "MIN";
|
||||
case static_cast<int64_t>(AGGREGATE_FUNCTION::MAX):
|
||||
return "MAX";
|
||||
default:
|
||||
ORT_THROW("Unknown value for aggregate_function.");
|
||||
}
|
||||
}
|
||||
|
||||
std::string postTransformToString() const {
|
||||
switch (post_transform) {
|
||||
case static_cast<int64_t>(POST_EVAL_TRANSFORM::NONE):
|
||||
return "NONE";
|
||||
case static_cast<int64_t>(POST_EVAL_TRANSFORM::SOFTMAX):
|
||||
return "SOFTMAX";
|
||||
case static_cast<int64_t>(POST_EVAL_TRANSFORM::LOGISTIC):
|
||||
return "LOGISTIC";
|
||||
case static_cast<int64_t>(POST_EVAL_TRANSFORM::SOFTMAX_ZERO):
|
||||
return "SOFTMAX_ZERO";
|
||||
case static_cast<int64_t>(POST_EVAL_TRANSFORM::PROBIT):
|
||||
return "PROBIT";
|
||||
default:
|
||||
ORT_THROW("Unknown value for post_transform.");
|
||||
}
|
||||
}
|
||||
|
||||
int64_t transformInputOneTree(
|
||||
const size_t curr_id, const int64_t curr_treeid, const int64_t curr_nodeid, const size_t curr_membership_value_id,
|
||||
const bool is_leaf, std::vector<std::vector<ThresholdType>>& membership_values_by_id,
|
||||
TreeEnsembleAttributesV3<ThresholdType>& output) const {
|
||||
output.nodes_nodeids.push_back(curr_nodeid);
|
||||
output.nodes_treeids.push_back(curr_treeid);
|
||||
|
||||
if (is_leaf) {
|
||||
output.nodes_modes.push_back(NODE_MODE_ONNX::LEAF);
|
||||
output.target_class_ids.push_back(leaf_targetids[curr_id]);
|
||||
output.target_class_nodeids.push_back(curr_nodeid);
|
||||
output.target_class_treeids.push_back(curr_treeid);
|
||||
output.target_class_weights_as_tensor.push_back(leaf_weights[curr_id]);
|
||||
|
||||
// the below are irrelevant for a `LEAF`
|
||||
output.nodes_featureids.push_back(0);
|
||||
output.nodes_truenodeids.push_back(0);
|
||||
output.nodes_falsenodeids.push_back(0);
|
||||
output.nodes_values_as_tensor.push_back(0);
|
||||
if (!nodes_hitrates.empty()) {
|
||||
output.nodes_hitrates.push_back(0);
|
||||
}
|
||||
if (!nodes_missing_value_tracks_true.empty()) {
|
||||
output.nodes_missing_value_tracks_true.push_back(0);
|
||||
}
|
||||
|
||||
return curr_nodeid;
|
||||
}
|
||||
|
||||
output.nodes_featureids.push_back(nodes_featureids[curr_id]);
|
||||
if (!nodes_hitrates.empty()) {
|
||||
output.nodes_hitrates_as_tensor.push_back(nodes_hitrates[curr_id]);
|
||||
}
|
||||
if (!nodes_missing_value_tracks_true.empty()) {
|
||||
output.nodes_missing_value_tracks_true.push_back(nodes_missing_value_tracks_true[curr_id]);
|
||||
}
|
||||
|
||||
// unroll `BRANCH_MEMBER` to a chain of `BRANCH_EQ`
|
||||
if (nodes_modes[curr_id] == NODE_MODE_ONNX::BRANCH_MEMBER) {
|
||||
output.nodes_modes.push_back(NODE_MODE_ONNX::BRANCH_EQ);
|
||||
output.nodes_values_as_tensor.push_back(membership_values_by_id[curr_id][curr_membership_value_id]);
|
||||
} else {
|
||||
output.nodes_modes.push_back(nodes_modes[curr_id]);
|
||||
output.nodes_values_as_tensor.push_back(nodes_splits[curr_id]);
|
||||
}
|
||||
|
||||
size_t falsenodeid_id = output.nodes_falsenodeids.size();
|
||||
output.nodes_falsenodeids.push_back(0); // change after pushing truenode subtree
|
||||
|
||||
int64_t true_nodeid = curr_nodeid + 1;
|
||||
output.nodes_truenodeids.push_back(true_nodeid);
|
||||
true_nodeid = transformInputOneTree(onnxruntime::narrow<size_t>(nodes_truenodeids[curr_id]),
|
||||
curr_treeid, true_nodeid, 0U, nodes_trueleafs[curr_id] != 0,
|
||||
membership_values_by_id, output);
|
||||
|
||||
int64_t false_nodeid = true_nodeid + 1;
|
||||
output.nodes_falsenodeids[falsenodeid_id] = false_nodeid;
|
||||
|
||||
// if node is `BRANCH_MEMBER` we are unrolling the `membership_values` for that node
|
||||
// therefore if the value is not the last, the `falsenode_id` must be pointing to the "same" node with a different membership value
|
||||
// so in that case we are only moving the pointer for `membership_values`
|
||||
//
|
||||
// otherwise, the `falsenode_id` is pointing to the real falsenode subtree
|
||||
if (nodes_modes[curr_id] == NODE_MODE_ONNX::BRANCH_MEMBER &&
|
||||
curr_membership_value_id + 1 < membership_values_by_id[curr_id].size()) {
|
||||
false_nodeid = transformInputOneTree(curr_id, curr_treeid, false_nodeid, curr_membership_value_id + 1, false,
|
||||
membership_values_by_id, output);
|
||||
} else {
|
||||
false_nodeid = transformInputOneTree(onnxruntime::narrow<size_t>(nodes_falsenodeids[curr_id]),
|
||||
curr_treeid, false_nodeid, 0U, nodes_falseleafs[curr_id] != 0,
|
||||
membership_values_by_id, output);
|
||||
}
|
||||
return false_nodeid;
|
||||
}
|
||||
|
||||
void transformInputAllTrees(TreeEnsembleAttributesV3<ThresholdType>& output,
|
||||
std::vector<std::vector<ThresholdType>>& membership_values_by_id) const {
|
||||
int64_t curr_treeid = 0;
|
||||
for (const int64_t& tree_root : tree_roots) {
|
||||
size_t tree_root_size_t = onnxruntime::narrow<size_t>(tree_root);
|
||||
transformInputOneTree(tree_root_size_t, curr_treeid, 0, 0U,
|
||||
nodes_falsenodeids[tree_root_size_t] == nodes_truenodeids[tree_root_size_t],
|
||||
membership_values_by_id, output);
|
||||
curr_treeid++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace ml
|
||||
} // namespace onnxruntime
|
|
@ -3,15 +3,21 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "tree_ensemble_aggregator.h"
|
||||
#include <mutex>
|
||||
#include "core/platform/threadpool.h"
|
||||
#include "tree_ensemble_helper.h"
|
||||
#include "tree_ensemble_attribute.h"
|
||||
#include "tree_ensemble_aggregator.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace ml {
|
||||
namespace detail {
|
||||
|
||||
/**
|
||||
* These attributes are the kernel attributes. They are different from the onnx operator attributes
|
||||
* to improve the computation efficiency. The initialization consists in moving the onnx attributes
|
||||
* into the kernel attributes.
|
||||
*/
|
||||
class TreeEnsembleCommonAttributes {
|
||||
public:
|
||||
int64_t get_target_or_class_count() const { return this->n_targets_or_classes_; }
|
||||
|
@ -57,27 +63,7 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
|
|||
Status Init(int parallel_tree,
|
||||
int parallel_tree_N,
|
||||
int parallel_N,
|
||||
const std::string& aggregate_function,
|
||||
const std::vector<float>& base_values,
|
||||
const std::vector<ThresholdType>& base_values_as_tensor,
|
||||
int64_t n_targets_or_classes,
|
||||
const std::vector<int64_t>& nodes_falsenodeids,
|
||||
const std::vector<int64_t>& nodes_featureids,
|
||||
const std::vector<float>& nodes_hitrates,
|
||||
const std::vector<ThresholdType>& nodes_hitrates_as_tensor,
|
||||
const std::vector<int64_t>& nodes_missing_value_tracks_true,
|
||||
const std::vector<std::string>& nodes_modes,
|
||||
const std::vector<int64_t>& nodes_nodeids,
|
||||
const std::vector<int64_t>& nodes_treeids,
|
||||
const std::vector<int64_t>& nodes_truenodeids,
|
||||
const std::vector<float>& nodes_values,
|
||||
const std::vector<ThresholdType>& nodes_values_as_tensor,
|
||||
const std::string& post_transform,
|
||||
const std::vector<int64_t>& target_class_ids,
|
||||
const std::vector<int64_t>& target_class_nodeids,
|
||||
const std::vector<int64_t>& target_class_treeids,
|
||||
const std::vector<float>& target_class_weights,
|
||||
const std::vector<ThresholdType>& target_class_weights_as_tensor);
|
||||
const TreeEnsembleAttributesV3<ThresholdType>& attributes);
|
||||
|
||||
protected:
|
||||
TreeNodeElement<ThresholdType>* ProcessTreeNodeLeave(TreeNodeElement<ThresholdType>* root,
|
||||
|
@ -87,49 +73,52 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
|
|||
void ComputeAgg(concurrency::ThreadPool* ttp, const Tensor* X, Tensor* Y, Tensor* label, const AGG& agg) const;
|
||||
|
||||
private:
|
||||
size_t AddNodes(const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
|
||||
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
|
||||
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
|
||||
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping,
|
||||
int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids);
|
||||
bool CheckIfSubtreesAreEqual(const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector<NODE_MODE_ONNX>& cmodes,
|
||||
const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, gsl::span<const int64_t> nodes_featureids,
|
||||
gsl::span<const ThresholdType> nodes_values_as_tensor, gsl::span<const float> node_values,
|
||||
gsl::span<const float> target_class_weights, gsl::span<const ThresholdType> target_class_weights_as_tensor,
|
||||
const InlinedVector<TreeNodeElementId>& node_tree_ids, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices);
|
||||
size_t AddNodes(const size_t i, const InlinedVector<NODE_MODE_ONNX>& cmodes, const InlinedVector<size_t>& truenode_ids,
|
||||
const InlinedVector<size_t>& falsenode_ids, gsl::span<const int64_t> nodes_featureids,
|
||||
gsl::span<const ThresholdType> nodes_values_as_tensor, gsl::span<const float> node_values,
|
||||
gsl::span<const int64_t> nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping,
|
||||
int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids, gsl::span<const float> target_class_weights,
|
||||
gsl::span<const ThresholdType> target_class_weights_as_tensor, InlinedVector<std::pair<TreeNodeElementId, uint32_t>>& indices);
|
||||
};
|
||||
|
||||
// Below is simple implementation of `bit_cast` as it is supported from c++20 and the current supported version is c++17
|
||||
// Remove it when that is not the case
|
||||
template <class To, class From>
|
||||
std::enable_if_t<
|
||||
sizeof(To) == sizeof(From) &&
|
||||
std::is_trivially_copyable_v<From> &&
|
||||
std::is_trivially_copyable_v<To>,
|
||||
To>
|
||||
// constexpr support needs compiler magic
|
||||
static bit_cast(const From& src) noexcept {
|
||||
static_assert(std::is_trivially_constructible_v<To>,
|
||||
"This implementation additionally requires "
|
||||
"destination type to be trivially constructible");
|
||||
|
||||
To dst;
|
||||
std::memcpy(&dst, &src, sizeof(To));
|
||||
return dst;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::conditional_t<sizeof(T) == sizeof(uint32_t), uint32_t, uint64_t> bit_cast_int(T val) {
|
||||
if constexpr (sizeof(T) == sizeof(uint32_t)) {
|
||||
return bit_cast<uint32_t>(val);
|
||||
} else if constexpr (sizeof(T) == sizeof(uint64_t)) {
|
||||
return bit_cast<uint64_t>(val);
|
||||
}
|
||||
static_assert(sizeof(T) == sizeof(uint32_t) || sizeof(T) == sizeof(uint64_t));
|
||||
}
|
||||
|
||||
template <typename InputType, typename ThresholdType, typename OutputType>
|
||||
Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(const OpKernelInfo& info) {
|
||||
std::vector<ThresholdType> base_values_as_tensor, nodes_hitrates_as_tensor,
|
||||
nodes_values_as_tensor, target_weights_as_tensor;
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor));
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor));
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor));
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "target_weights_as_tensor", target_weights_as_tensor));
|
||||
#endif
|
||||
|
||||
return Init(
|
||||
80,
|
||||
128,
|
||||
50,
|
||||
info.GetAttrOrDefault<std::string>("aggregate_function", "SUM"),
|
||||
info.GetAttrsOrDefault<float>("base_values"),
|
||||
base_values_as_tensor,
|
||||
info.GetAttrOrDefault<int64_t>("n_targets", 0),
|
||||
info.GetAttrsOrDefault<int64_t>("nodes_falsenodeids"),
|
||||
info.GetAttrsOrDefault<int64_t>("nodes_featureids"),
|
||||
info.GetAttrsOrDefault<float>("nodes_hitrates"),
|
||||
nodes_hitrates_as_tensor,
|
||||
info.GetAttrsOrDefault<int64_t>("nodes_missing_value_tracks_true"),
|
||||
info.GetAttrsOrDefault<std::string>("nodes_modes"),
|
||||
info.GetAttrsOrDefault<int64_t>("nodes_nodeids"),
|
||||
info.GetAttrsOrDefault<int64_t>("nodes_treeids"),
|
||||
info.GetAttrsOrDefault<int64_t>("nodes_truenodeids"),
|
||||
info.GetAttrsOrDefault<float>("nodes_values"),
|
||||
nodes_values_as_tensor,
|
||||
info.GetAttrOrDefault<std::string>("post_transform", "NONE"),
|
||||
info.GetAttrsOrDefault<int64_t>("target_ids"),
|
||||
info.GetAttrsOrDefault<int64_t>("target_nodeids"),
|
||||
info.GetAttrsOrDefault<int64_t>("target_treeids"),
|
||||
info.GetAttrsOrDefault<float>("target_weights"),
|
||||
target_weights_as_tensor);
|
||||
TreeEnsembleAttributesV3<ThresholdType> attributes(info, false);
|
||||
return Init(80, 128, 50, attributes);
|
||||
}
|
||||
|
||||
template <typename InputType, typename ThresholdType, typename OutputType>
|
||||
|
@ -137,72 +126,35 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
|
|||
int parallel_tree,
|
||||
int parallel_tree_N,
|
||||
int parallel_N,
|
||||
const std::string& aggregate_function,
|
||||
const std::vector<float>& base_values,
|
||||
const std::vector<ThresholdType>& base_values_as_tensor,
|
||||
int64_t n_targets_or_classes,
|
||||
const std::vector<int64_t>& nodes_falsenodeids,
|
||||
const std::vector<int64_t>& nodes_featureids,
|
||||
const std::vector<float>& nodes_hitrates,
|
||||
const std::vector<ThresholdType>& nodes_hitrates_as_tensor,
|
||||
const std::vector<int64_t>& nodes_missing_value_tracks_true,
|
||||
const std::vector<std::string>& nodes_modes,
|
||||
const std::vector<int64_t>& nodes_nodeids,
|
||||
const std::vector<int64_t>& nodes_treeids,
|
||||
const std::vector<int64_t>& nodes_truenodeids,
|
||||
const std::vector<float>& nodes_values,
|
||||
const std::vector<ThresholdType>& nodes_values_as_tensor,
|
||||
const std::string& post_transform,
|
||||
const std::vector<int64_t>& target_class_ids,
|
||||
const std::vector<int64_t>& target_class_nodeids,
|
||||
const std::vector<int64_t>& target_class_treeids,
|
||||
const std::vector<float>& target_class_weights,
|
||||
const std::vector<ThresholdType>& target_class_weights_as_tensor) {
|
||||
const TreeEnsembleAttributesV3<ThresholdType>& attributes) {
|
||||
parallel_tree_ = parallel_tree;
|
||||
parallel_tree_N_ = parallel_tree_N;
|
||||
parallel_N_ = parallel_N;
|
||||
|
||||
ORT_ENFORCE(n_targets_or_classes > 0);
|
||||
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_featureids.size());
|
||||
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_modes.size());
|
||||
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_nodeids.size());
|
||||
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_treeids.size());
|
||||
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_truenodeids.size());
|
||||
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_values.size() ||
|
||||
nodes_falsenodeids.size() == nodes_values_as_tensor.size());
|
||||
ORT_ENFORCE(target_class_ids.size() == target_class_nodeids.size());
|
||||
ORT_ENFORCE(target_class_ids.size() == target_class_treeids.size());
|
||||
ORT_ENFORCE(target_class_weights.empty() || target_class_ids.size() == target_class_weights.size());
|
||||
ORT_ENFORCE(base_values.empty() || base_values_as_tensor.empty());
|
||||
ORT_ENFORCE(nodes_hitrates.empty() || nodes_hitrates_as_tensor.empty());
|
||||
ORT_ENFORCE(nodes_values.empty() || nodes_values_as_tensor.empty());
|
||||
ORT_ENFORCE(target_class_weights.empty() || target_class_weights_as_tensor.empty());
|
||||
|
||||
aggregate_function_ = MakeAggregateFunction(aggregate_function);
|
||||
post_transform_ = MakeTransform(post_transform);
|
||||
if (!base_values_as_tensor.empty()) {
|
||||
ORT_ENFORCE(base_values.empty());
|
||||
base_values_ = base_values_as_tensor;
|
||||
aggregate_function_ = MakeAggregateFunction(attributes.aggregate_function);
|
||||
post_transform_ = MakeTransform(attributes.post_transform);
|
||||
if (!attributes.base_values_as_tensor.empty()) {
|
||||
ORT_ENFORCE(attributes.base_values.empty());
|
||||
base_values_ = attributes.base_values_as_tensor;
|
||||
} else {
|
||||
base_values_.reserve(base_values.size());
|
||||
for (size_t i = 0, limit = base_values.size(); i < limit; ++i) {
|
||||
base_values_.push_back(static_cast<ThresholdType>(base_values[i]));
|
||||
base_values_.reserve(attributes.base_values.size());
|
||||
for (size_t i = 0, limit = attributes.base_values.size(); i < limit; ++i) {
|
||||
base_values_.push_back(static_cast<ThresholdType>(attributes.base_values[i]));
|
||||
}
|
||||
}
|
||||
n_targets_or_classes_ = n_targets_or_classes;
|
||||
n_targets_or_classes_ = attributes.n_targets_or_classes;
|
||||
max_tree_depth_ = 1000;
|
||||
ORT_ENFORCE(nodes_modes.size() < std::numeric_limits<uint32_t>::max());
|
||||
|
||||
// Additional members
|
||||
size_t limit;
|
||||
uint32_t i;
|
||||
InlinedVector<NODE_MODE> cmodes;
|
||||
cmodes.reserve(nodes_modes.size());
|
||||
InlinedVector<NODE_MODE_ONNX> cmodes;
|
||||
cmodes.reserve(attributes.nodes_modes.size());
|
||||
same_mode_ = true;
|
||||
int fpos = -1;
|
||||
for (i = 0, limit = nodes_modes.size(); i < limit; ++i) {
|
||||
cmodes.push_back(MakeTreeNodeMode(nodes_modes[i]));
|
||||
if (cmodes[i] == NODE_MODE::LEAF) continue;
|
||||
for (i = 0, limit = attributes.nodes_modes.size(); i < limit; ++i) {
|
||||
cmodes.push_back(attributes.nodes_modes[i]);
|
||||
if (cmodes[i] == NODE_MODE_ONNX::LEAF) continue;
|
||||
if (fpos == -1) {
|
||||
fpos = static_cast<int>(i);
|
||||
continue;
|
||||
|
@ -210,7 +162,7 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
|
|||
if (cmodes[i] != cmodes[fpos]) same_mode_ = false;
|
||||
}
|
||||
|
||||
n_nodes_ = nodes_treeids.size();
|
||||
n_nodes_ = attributes.nodes_treeids.size();
|
||||
limit = static_cast<size_t>(n_nodes_);
|
||||
InlinedVector<TreeNodeElementId> node_tree_ids;
|
||||
node_tree_ids.reserve(limit);
|
||||
|
@ -227,7 +179,7 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
|
|||
|
||||
// Build node_tree_ids and node_tree_ids_map and truenode_ids and falsenode_ids
|
||||
for (i = 0; i < limit; ++i) {
|
||||
TreeNodeElementId node_tree_id{static_cast<int>(nodes_treeids[i]), static_cast<int>(nodes_nodeids[i])};
|
||||
TreeNodeElementId node_tree_id{static_cast<int>(attributes.nodes_treeids[i]), static_cast<int>(attributes.nodes_nodeids[i])};
|
||||
auto p = node_tree_ids_map.insert(std::pair<TreeNodeElementId, size_t>(node_tree_id, i));
|
||||
if (!p.second) {
|
||||
ORT_THROW("Node ", node_tree_id.node_id, " in tree ", node_tree_id.tree_id, " is already there.");
|
||||
|
@ -237,13 +189,13 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
|
|||
|
||||
TreeNodeElementId coor;
|
||||
for (i = 0; i < limit; ++i) {
|
||||
if (cmodes[i] == NODE_MODE::LEAF) {
|
||||
if (cmodes[i] == NODE_MODE_ONNX::LEAF) {
|
||||
truenode_ids.push_back(0);
|
||||
falsenode_ids.push_back(0);
|
||||
} else {
|
||||
TreeNodeElementId& node_tree_id = node_tree_ids[i];
|
||||
coor.tree_id = node_tree_id.tree_id;
|
||||
coor.node_id = static_cast<int>(nodes_truenodeids[i]);
|
||||
coor.node_id = static_cast<int>(attributes.nodes_truenodeids[i]);
|
||||
ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_));
|
||||
|
||||
auto found = node_tree_ids_map.find(coor);
|
||||
|
@ -255,7 +207,7 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
|
|||
}
|
||||
truenode_ids.emplace_back(found->second);
|
||||
|
||||
coor.node_id = static_cast<int>(nodes_falsenodeids[i]);
|
||||
coor.node_id = static_cast<int>(attributes.nodes_falsenodeids[i]);
|
||||
ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_));
|
||||
found = node_tree_ids_map.find(coor);
|
||||
if (found == node_tree_ids_map.end()) {
|
||||
|
@ -270,41 +222,38 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
|
|||
}
|
||||
}
|
||||
|
||||
// Sort targets
|
||||
InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices;
|
||||
indices.reserve(attributes.target_class_nodeids.size());
|
||||
for (i = 0, limit = attributes.target_class_nodeids.size(); i < limit; i++) {
|
||||
indices.emplace_back(
|
||||
TreeNodeElementId{attributes.target_class_treeids[i], attributes.target_class_nodeids[i]}, i);
|
||||
}
|
||||
|
||||
std::sort(indices.begin(), indices.end());
|
||||
|
||||
// Let's construct nodes_ such that the false branch is always the next element in nodes_.
|
||||
// updated_mapping will translates the old position of each node to the new node position in nodes_.
|
||||
std::vector<size_t> updated_mapping(nodes_treeids.size(), 0);
|
||||
std::vector<size_t> updated_mapping(attributes.nodes_treeids.size(), 0);
|
||||
int64_t previous_tree_id = -1;
|
||||
for (i = 0; i < n_nodes_; ++i) {
|
||||
if (previous_tree_id == -1 || (previous_tree_id != node_tree_ids[i].tree_id)) {
|
||||
// New tree.
|
||||
int64_t tree_id = node_tree_ids[i].tree_id;
|
||||
size_t root_position =
|
||||
AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values,
|
||||
nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
|
||||
AddNodes(i, cmodes, truenode_ids, falsenode_ids, attributes.nodes_featureids, attributes.nodes_values_as_tensor, attributes.nodes_values,
|
||||
attributes.nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
|
||||
attributes.target_class_weights, attributes.target_class_weights_as_tensor, indices);
|
||||
roots_.push_back(&nodes_[root_position]);
|
||||
previous_tree_id = tree_id;
|
||||
}
|
||||
}
|
||||
|
||||
n_trees_ = roots_.size();
|
||||
if (((int64_t)nodes_.size()) != n_nodes_) {
|
||||
ORT_THROW("Number of nodes in nodes_ (", nodes_.size(), ") is different from n_nodes (", n_nodes_, ").");
|
||||
}
|
||||
|
||||
// Sort targets
|
||||
InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices;
|
||||
indices.reserve(target_class_nodeids.size());
|
||||
for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) {
|
||||
indices.emplace_back(
|
||||
std::pair<TreeNodeElementId, uint32_t>(TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i));
|
||||
}
|
||||
|
||||
std::sort(indices.begin(), indices.end());
|
||||
|
||||
TreeNodeElementId ind;
|
||||
SparseValue<ThresholdType> w;
|
||||
size_t indi;
|
||||
for (indi = 0, limit = target_class_nodeids.size(); indi < limit; ++indi) {
|
||||
for (indi = 0, limit = attributes.target_class_nodeids.size(); indi < limit; ++indi) {
|
||||
ind = indices[indi].first;
|
||||
i = indices[indi].second;
|
||||
auto found = node_tree_ids_map.find(ind);
|
||||
|
@ -319,9 +268,10 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
|
|||
// ORT_THROW("Node ", ind.tree_id, "-", ind.node_id, " is not a leaf.");
|
||||
continue;
|
||||
}
|
||||
w.i = target_class_ids[i];
|
||||
w.value = target_class_weights_as_tensor.empty() ? static_cast<ThresholdType>(target_class_weights[i])
|
||||
: target_class_weights_as_tensor[i];
|
||||
w.i = attributes.target_class_ids[i];
|
||||
w.value = attributes.target_class_weights_as_tensor.empty()
|
||||
? static_cast<ThresholdType>(attributes.target_class_weights[i])
|
||||
: attributes.target_class_weights_as_tensor[i];
|
||||
if (leaf.truenode_or_weight.weight_data.n_weights == 0) {
|
||||
leaf.truenode_or_weight.weight_data.weight = static_cast<int32_t>(weights_.size());
|
||||
leaf.value_or_unique_weight = w.value;
|
||||
|
@ -331,7 +281,7 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
|
|||
}
|
||||
|
||||
has_missing_tracks_ = false;
|
||||
for (auto itm = nodes_missing_value_tracks_true.begin(); itm != nodes_missing_value_tracks_true.end(); ++itm) {
|
||||
for (auto itm = attributes.nodes_missing_value_tracks_true.begin(); itm != attributes.nodes_missing_value_tracks_true.end(); ++itm) {
|
||||
if (*itm) {
|
||||
has_missing_tracks_ = true;
|
||||
break;
|
||||
|
@ -341,13 +291,58 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename InputType, typename ThresholdType, typename OutputType>
|
||||
bool TreeEnsembleCommon<InputType, ThresholdType, OutputType>::CheckIfSubtreesAreEqual(
|
||||
const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector<NODE_MODE_ONNX>& cmodes,
|
||||
const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, gsl::span<const int64_t> nodes_featureids,
|
||||
gsl::span<const ThresholdType> nodes_values_as_tensor, gsl::span<const float> node_values,
|
||||
gsl::span<const float> target_class_weights, gsl::span<const ThresholdType> target_class_weights_as_tensor,
|
||||
const InlinedVector<TreeNodeElementId>& node_tree_ids, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices) {
|
||||
// Leaves have values set at 0
|
||||
if (cmodes[left_id] != cmodes[right_id] || nodes_featureids[left_id] != nodes_featureids[right_id] ||
|
||||
(!nodes_values_as_tensor.empty() && nodes_values_as_tensor[left_id] != nodes_values_as_tensor[right_id]) ||
|
||||
(nodes_values_as_tensor.empty() && node_values[left_id] != node_values[right_id])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cmodes[left_id] == NODE_MODE_ONNX::LEAF) {
|
||||
const auto left_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[left_id], uint32_t(0)))->second;
|
||||
const auto right_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[right_id], uint32_t(0)))->second;
|
||||
|
||||
if (target_class_weights_as_tensor.empty()) {
|
||||
return target_class_weights[left_target_node] == target_class_weights[right_target_node];
|
||||
} else {
|
||||
return target_class_weights_as_tensor[left_target_node] == target_class_weights_as_tensor[right_target_node];
|
||||
}
|
||||
}
|
||||
|
||||
return CheckIfSubtreesAreEqual(falsenode_ids[left_id], falsenode_ids[right_id], tree_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids,
|
||||
nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices) &&
|
||||
CheckIfSubtreesAreEqual(truenode_ids[left_id], truenode_ids[right_id], tree_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids,
|
||||
nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices);
|
||||
}
|
||||
|
||||
inline void UpdateThreshold(double val, double& mask) {
|
||||
uint64_t new_mask = bit_cast<uint64_t>(mask) | (1ll << (static_cast<uint32_t>(val) - 1));
|
||||
mask = bit_cast<double>(new_mask);
|
||||
}
|
||||
|
||||
inline void UpdateThreshold(float val, float& mask) {
|
||||
uint32_t new_mask = bit_cast<uint32_t>(mask) | (1 << (static_cast<uint32_t>(val) - 1));
|
||||
mask = bit_cast<float>(new_mask);
|
||||
}
|
||||
|
||||
#define BITCOUNT(T) int64_t(sizeof(T) * 8)
|
||||
#define CANMASK(v, T) (v >= 1 && v <= BITCOUNT(T)) && v == std::floor(v)
|
||||
|
||||
template <typename InputType, typename ThresholdType, typename OutputType>
|
||||
size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
|
||||
const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
|
||||
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
|
||||
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
|
||||
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping, int64_t tree_id,
|
||||
const InlinedVector<TreeNodeElementId>& node_tree_ids) {
|
||||
const size_t i, const InlinedVector<NODE_MODE_ONNX>& cmodes, const InlinedVector<size_t>& truenode_ids,
|
||||
const InlinedVector<size_t>& falsenode_ids, gsl::span<const int64_t> nodes_featureids,
|
||||
gsl::span<const ThresholdType> nodes_values_as_tensor, gsl::span<const float> node_values,
|
||||
gsl::span<const int64_t> nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping, int64_t tree_id,
|
||||
const InlinedVector<TreeNodeElementId>& node_tree_ids, gsl::span<const float> target_class_weights,
|
||||
gsl::span<const ThresholdType> target_class_weights_as_tensor, InlinedVector<std::pair<TreeNodeElementId, uint32_t>>& indices) {
|
||||
// Validate this index maps to the same tree_id as the one we should be building.
|
||||
if (node_tree_ids[i].tree_id != tree_id) {
|
||||
ORT_THROW("Tree id mismatch. Expected ", tree_id, " but got ", node_tree_ids[i].tree_id, " at position ", i);
|
||||
|
@ -364,28 +359,59 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
|
|||
updated_mapping[i] = node_pos;
|
||||
|
||||
TreeNodeElement<ThresholdType> node;
|
||||
node.flags = static_cast<uint8_t>(cmodes[i]);
|
||||
node.flags = Convert_NODE_MODE_ONNX_to_ORT(cmodes[i]);
|
||||
node.feature_id = static_cast<int>(nodes_featureids[i]);
|
||||
if (node.feature_id > max_feature_id_) {
|
||||
max_feature_id_ = node.feature_id;
|
||||
}
|
||||
node.value_or_unique_weight =
|
||||
nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];
|
||||
|
||||
node.value_or_unique_weight = 0;
|
||||
const ThresholdType node_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];
|
||||
if (node.flags == NODE_MODE_ORT::BRANCH_EQ && CANMASK(node_threshold, ThresholdType)) {
|
||||
UpdateThreshold(node_threshold, node.value_or_unique_weight);
|
||||
node.flags = NODE_MODE_ORT::BRANCH_MEMBER;
|
||||
} else {
|
||||
node.value_or_unique_weight = node_threshold;
|
||||
}
|
||||
|
||||
if (i < static_cast<size_t>(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) {
|
||||
node.flags |= static_cast<uint8_t>(MissingTrack::kTrue);
|
||||
node.flags = static_cast<NODE_MODE_ORT>(static_cast<uint8_t>(node.flags) | static_cast<uint8_t>(MissingTrack::kTrue));
|
||||
}
|
||||
nodes_.push_back(std::move(node));
|
||||
if (nodes_[node_pos].is_not_leaf()) {
|
||||
size_t falsenode_id = falsenode_ids[i];
|
||||
|
||||
// Categoricals are represented as a chain of `EQ` nodes where the subtree for the true child is identical for all nodes in the chain
|
||||
// Below we are folding together these nodes into one of mode `BRANCH_MEMBER`
|
||||
// The threshold of this node should be interpreted as a bitmask showing which categoricals values were found in the chain
|
||||
// Afterwards, when looking whether a feature is included we can do an `and` with the mask of the node
|
||||
// and the one of the feature (the mask has only one bit set on the place for its value)
|
||||
// Beware that if a category is bigger than the threshold type, the node stays as `EQ` and no combination is done
|
||||
if (nodes_[node_pos].flags == NODE_MODE_ORT::BRANCH_MEMBER) {
|
||||
ThresholdType falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];
|
||||
|
||||
while (cmodes[falsenode_id] == NODE_MODE_ONNX::BRANCH_EQ && nodes_[node_pos].feature_id == nodes_featureids[falsenode_id] &&
|
||||
CANMASK(falsenode_threshold, ThresholdType) &&
|
||||
CheckIfSubtreesAreEqual(truenode_ids[i], truenode_ids[falsenode_id], tree_id, cmodes, truenode_ids, falsenode_ids,
|
||||
nodes_featureids, nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices)) {
|
||||
UpdateThreshold(falsenode_threshold, nodes_[node_pos].value_or_unique_weight);
|
||||
falsenode_id = falsenode_ids[falsenode_id];
|
||||
falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];
|
||||
}
|
||||
}
|
||||
|
||||
size_t false_branch =
|
||||
AddNodes(falsenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
|
||||
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
|
||||
AddNodes(falsenode_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
|
||||
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
|
||||
target_class_weights, target_class_weights_as_tensor, indices);
|
||||
if (false_branch != node_pos + 1) {
|
||||
ORT_THROW("False node must always be the next node, but it isn't at index ", node_pos, " with flags ",
|
||||
static_cast<int>(nodes_[node_pos].flags));
|
||||
}
|
||||
size_t true_branch =
|
||||
AddNodes(truenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
|
||||
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
|
||||
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
|
||||
target_class_weights, target_class_weights_as_tensor, indices);
|
||||
// We don't need to store the false branch pointer since we know it is always in the immediate next entry in nodes_.
|
||||
// nodes_[node_pos].falsenode_inc_or_n_weights.ptr = &nodes_[false_branch];
|
||||
nodes_[node_pos].truenode_or_weight.ptr = &nodes_[true_branch];
|
||||
|
@ -684,10 +710,12 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
|
|||
} \
|
||||
}
|
||||
|
||||
inline bool _isnan_(float x) { return std::isnan(x); }
|
||||
inline bool _isnan_(double x) { return std::isnan(x); }
|
||||
inline bool _isnan_(int64_t) { return false; }
|
||||
inline bool _isnan_(int32_t) { return false; }
|
||||
// Check whether the feature value is set true in the mask
|
||||
template <typename T1, typename T2>
|
||||
inline bool SetMembershipCheck(T1 val, T2 mask) {
|
||||
const int64_t val_as_int = static_cast<int64_t>(val);
|
||||
return CANMASK(val, T2) && (((1ll << (val_as_int - 1)) & bit_cast_int(mask)) != 0);
|
||||
}
|
||||
|
||||
template <typename InputType, typename ThresholdType, typename OutputType>
|
||||
TreeNodeElement<ThresholdType>*
|
||||
|
@ -696,7 +724,7 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
|
|||
InputType val;
|
||||
if (same_mode_) {
|
||||
switch (root->mode()) {
|
||||
case NODE_MODE::BRANCH_LEQ:
|
||||
case NODE_MODE_ORT::BRANCH_LEQ:
|
||||
if (has_missing_tracks_) {
|
||||
while (root->is_not_leaf()) {
|
||||
val = x_data[root->feature_id];
|
||||
|
@ -711,22 +739,36 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
|
|||
}
|
||||
}
|
||||
break;
|
||||
case NODE_MODE::BRANCH_LT:
|
||||
case NODE_MODE_ORT::BRANCH_LT:
|
||||
TREE_FIND_VALUE(<)
|
||||
break;
|
||||
case NODE_MODE::BRANCH_GTE:
|
||||
case NODE_MODE_ORT::BRANCH_GTE:
|
||||
TREE_FIND_VALUE(>=)
|
||||
break;
|
||||
case NODE_MODE::BRANCH_GT:
|
||||
case NODE_MODE_ORT::BRANCH_GT:
|
||||
TREE_FIND_VALUE(>)
|
||||
break;
|
||||
case NODE_MODE::BRANCH_EQ:
|
||||
case NODE_MODE_ORT::BRANCH_EQ:
|
||||
TREE_FIND_VALUE(==)
|
||||
break;
|
||||
case NODE_MODE::BRANCH_NEQ:
|
||||
case NODE_MODE_ORT::BRANCH_NEQ:
|
||||
TREE_FIND_VALUE(!=)
|
||||
break;
|
||||
case NODE_MODE::LEAF:
|
||||
case NODE_MODE_ORT::BRANCH_MEMBER:
|
||||
if (has_missing_tracks_) {
|
||||
while (root->is_not_leaf()) {
|
||||
val = x_data[root->feature_id];
|
||||
root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val)))
|
||||
? root->truenode_or_weight.ptr
|
||||
: root + 1;
|
||||
}
|
||||
} else {
|
||||
while (root->is_not_leaf()) {
|
||||
val = x_data[root->feature_id];
|
||||
root = SetMembershipCheck(val, root->value_or_unique_weight) ? root->truenode_or_weight.ptr : root + 1;
|
||||
}
|
||||
}
|
||||
case NODE_MODE_ORT::LEAF:
|
||||
break;
|
||||
}
|
||||
} else { // Different rules to compare to node thresholds.
|
||||
|
@ -735,31 +777,36 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
|
|||
val = x_data[root->feature_id];
|
||||
threshold = root->value_or_unique_weight;
|
||||
switch (root->mode()) {
|
||||
case NODE_MODE::BRANCH_LEQ:
|
||||
case NODE_MODE_ORT::BRANCH_LEQ:
|
||||
root = val <= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
|
||||
: root + 1;
|
||||
break;
|
||||
case NODE_MODE::BRANCH_LT:
|
||||
case NODE_MODE_ORT::BRANCH_LT:
|
||||
root = val < threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
|
||||
: root + 1;
|
||||
break;
|
||||
case NODE_MODE::BRANCH_GTE:
|
||||
case NODE_MODE_ORT::BRANCH_GTE:
|
||||
root = val >= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
|
||||
: root + 1;
|
||||
break;
|
||||
case NODE_MODE::BRANCH_GT:
|
||||
case NODE_MODE_ORT::BRANCH_GT:
|
||||
root = val > threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
|
||||
: root + 1;
|
||||
break;
|
||||
case NODE_MODE::BRANCH_EQ:
|
||||
case NODE_MODE_ORT::BRANCH_EQ:
|
||||
root = val == threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
|
||||
: root + 1;
|
||||
break;
|
||||
case NODE_MODE::BRANCH_NEQ:
|
||||
case NODE_MODE_ORT::BRANCH_NEQ:
|
||||
root = val != threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
|
||||
: root + 1;
|
||||
break;
|
||||
case NODE_MODE::LEAF:
|
||||
case NODE_MODE_ORT::BRANCH_MEMBER:
|
||||
root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val)))
|
||||
? root->truenode_or_weight.ptr
|
||||
: root + 1;
|
||||
break;
|
||||
case NODE_MODE_ORT::LEAF:
|
||||
return root;
|
||||
}
|
||||
}
|
||||
|
@ -786,67 +833,13 @@ class TreeEnsembleCommonClassifier : public TreeEnsembleCommon<InputType, Thresh
|
|||
Status Init(int parallel_tree,
|
||||
int parallel_tree_N,
|
||||
int parallel_N,
|
||||
const std::string& aggregate_function,
|
||||
const std::vector<float>& base_values,
|
||||
const std::vector<ThresholdType>& base_values_as_tensor,
|
||||
const std::vector<int64_t>& nodes_falsenodeids,
|
||||
const std::vector<int64_t>& nodes_featureids,
|
||||
const std::vector<float>& nodes_hitrates,
|
||||
const std::vector<ThresholdType>& nodes_hitrates_as_tensor,
|
||||
const std::vector<int64_t>& nodes_missing_value_tracks_true,
|
||||
const std::vector<std::string>& nodes_modes,
|
||||
const std::vector<int64_t>& nodes_nodeids,
|
||||
const std::vector<int64_t>& nodes_treeids,
|
||||
const std::vector<int64_t>& nodes_truenodeids,
|
||||
const std::vector<float>& nodes_values,
|
||||
const std::vector<ThresholdType>& nodes_values_as_tensor,
|
||||
const std::string& post_transform,
|
||||
const std::vector<int64_t>& class_ids,
|
||||
const std::vector<int64_t>& class_nodeids,
|
||||
const std::vector<int64_t>& class_treeids,
|
||||
const std::vector<float>& class_weights,
|
||||
const std::vector<ThresholdType>& class_weights_as_tensor,
|
||||
const std::vector<std::string>& classlabels_strings,
|
||||
const std::vector<int64_t>& classlabels_int64s);
|
||||
const TreeEnsembleAttributesV3<ThresholdType>& attributes);
|
||||
};
|
||||
|
||||
template <typename InputType, typename ThresholdType, typename OutputType>
|
||||
Status TreeEnsembleCommonClassifier<InputType, ThresholdType, OutputType>::Init(const OpKernelInfo& info) {
|
||||
std::vector<ThresholdType> base_values_as_tensor, nodes_hitrates_as_tensor,
|
||||
nodes_values_as_tensor, class_weights_as_tensor;
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor));
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor));
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor));
|
||||
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "class_weights_as_tensor", class_weights_as_tensor));
|
||||
#endif
|
||||
|
||||
return Init(
|
||||
80,
|
||||
128,
|
||||
50,
|
||||
info.GetAttrOrDefault<std::string>("aggregate_function", "SUM"),
|
||||
info.GetAttrsOrDefault<float>("base_values"),
|
||||
base_values_as_tensor,
|
||||
info.GetAttrsOrDefault<int64_t>("nodes_falsenodeids"),
|
||||
info.GetAttrsOrDefault<int64_t>("nodes_featureids"),
|
||||
info.GetAttrsOrDefault<float>("nodes_hitrates"),
|
||||
nodes_hitrates_as_tensor,
|
||||
info.GetAttrsOrDefault<int64_t>("nodes_missing_value_tracks_true"),
|
||||
info.GetAttrsOrDefault<std::string>("nodes_modes"),
|
||||
info.GetAttrsOrDefault<int64_t>("nodes_nodeids"),
|
||||
info.GetAttrsOrDefault<int64_t>("nodes_treeids"),
|
||||
info.GetAttrsOrDefault<int64_t>("nodes_truenodeids"),
|
||||
info.GetAttrsOrDefault<float>("nodes_values"),
|
||||
nodes_values_as_tensor,
|
||||
info.GetAttrOrDefault<std::string>("post_transform", "NONE"),
|
||||
info.GetAttrsOrDefault<int64_t>("class_ids"),
|
||||
info.GetAttrsOrDefault<int64_t>("class_nodeids"),
|
||||
info.GetAttrsOrDefault<int64_t>("class_treeids"),
|
||||
info.GetAttrsOrDefault<float>("class_weights"),
|
||||
class_weights_as_tensor,
|
||||
info.GetAttrsOrDefault<std::string>("classlabels_strings"),
|
||||
info.GetAttrsOrDefault<int64_t>("classlabels_int64s"));
|
||||
TreeEnsembleAttributesV3<ThresholdType> attributes(info, true);
|
||||
return Init(80, 128, 50, attributes);
|
||||
}
|
||||
|
||||
template <typename InputType, typename ThresholdType, typename OutputType>
|
||||
|
@ -854,65 +847,20 @@ Status TreeEnsembleCommonClassifier<InputType, ThresholdType, OutputType>::Init(
|
|||
int parallel_tree,
|
||||
int parallel_tree_N,
|
||||
int parallel_N,
|
||||
const std::string& aggregate_function,
|
||||
const std::vector<float>& base_values,
|
||||
const std::vector<ThresholdType>& base_values_as_tensor,
|
||||
const std::vector<int64_t>& nodes_falsenodeids,
|
||||
const std::vector<int64_t>& nodes_featureids,
|
||||
const std::vector<float>& nodes_hitrates,
|
||||
const std::vector<ThresholdType>& nodes_hitrates_as_tensor,
|
||||
const std::vector<int64_t>& nodes_missing_value_tracks_true,
|
||||
const std::vector<std::string>& nodes_modes,
|
||||
const std::vector<int64_t>& nodes_nodeids,
|
||||
const std::vector<int64_t>& nodes_treeids,
|
||||
const std::vector<int64_t>& nodes_truenodeids,
|
||||
const std::vector<float>& nodes_values,
|
||||
const std::vector<ThresholdType>& nodes_values_as_tensor,
|
||||
const std::string& post_transform,
|
||||
const std::vector<int64_t>& class_ids,
|
||||
const std::vector<int64_t>& class_nodeids,
|
||||
const std::vector<int64_t>& class_treeids,
|
||||
const std::vector<float>& class_weights,
|
||||
const std::vector<ThresholdType>& class_weights_as_tensor,
|
||||
const std::vector<std::string>& classlabels_strings,
|
||||
const std::vector<int64_t>& classlabels_int64s) {
|
||||
auto status = TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
|
||||
parallel_tree,
|
||||
parallel_tree_N,
|
||||
parallel_N,
|
||||
aggregate_function,
|
||||
base_values,
|
||||
base_values_as_tensor,
|
||||
classlabels_strings.empty() ? classlabels_int64s.size()
|
||||
: classlabels_strings.size(),
|
||||
nodes_falsenodeids,
|
||||
nodes_featureids,
|
||||
nodes_hitrates,
|
||||
nodes_hitrates_as_tensor,
|
||||
nodes_missing_value_tracks_true,
|
||||
nodes_modes,
|
||||
nodes_nodeids,
|
||||
nodes_treeids,
|
||||
nodes_truenodeids,
|
||||
nodes_values,
|
||||
nodes_values_as_tensor,
|
||||
post_transform,
|
||||
class_ids,
|
||||
class_nodeids,
|
||||
class_treeids,
|
||||
class_weights,
|
||||
class_weights_as_tensor);
|
||||
const TreeEnsembleAttributesV3<ThresholdType>& attributes) {
|
||||
auto status = TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(parallel_tree, parallel_tree_N, parallel_N, attributes);
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
|
||||
classlabels_strings_ = classlabels_strings;
|
||||
classlabels_int64s_ = classlabels_int64s;
|
||||
classlabels_strings_ = attributes.classlabels_strings;
|
||||
classlabels_int64s_ = attributes.classlabels_int64s;
|
||||
|
||||
InlinedHashSet<int64_t> weights_classes;
|
||||
weights_classes.reserve(class_ids.size());
|
||||
weights_classes.reserve(attributes.target_class_ids.size());
|
||||
weights_are_all_positive_ = true;
|
||||
for (size_t i = 0, end = class_ids.size(); i < end; ++i) {
|
||||
weights_classes.insert(class_ids[i]);
|
||||
if (weights_are_all_positive_ && (!class_weights.empty() ? class_weights[i] : class_weights_as_tensor[i]) < 0)
|
||||
for (size_t i = 0, end = attributes.target_class_ids.size(); i < end; ++i) {
|
||||
weights_classes.insert(attributes.target_class_ids[i]);
|
||||
if (weights_are_all_positive_ && (!attributes.target_class_weights.empty() ? attributes.target_class_weights[i]
|
||||
: attributes.target_class_weights_as_tensor[i]) < 0)
|
||||
weights_are_all_positive_ = false;
|
||||
}
|
||||
binary_case_ = this->n_targets_or_classes_ == 2 && weights_classes.size() == 1;
|
||||
|
@ -957,6 +905,43 @@ Status TreeEnsembleCommonClassifier<InputType, ThresholdType, OutputType>::compu
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename IOType, typename ThresholdType>
|
||||
class TreeEnsembleCommonV5 : public TreeEnsembleCommon<IOType, ThresholdType, IOType> {
|
||||
public:
|
||||
virtual Status Init(const OpKernelInfo& info);
|
||||
|
||||
Status Init(int parallel_tree,
|
||||
int parallel_tree_N,
|
||||
int parallel_N,
|
||||
const TreeEnsembleAttributesV5<ThresholdType>& attributes);
|
||||
};
|
||||
|
||||
template <typename IOType, typename ThresholdType>
|
||||
Status TreeEnsembleCommonV5<IOType, ThresholdType>::Init(const OpKernelInfo& info) {
|
||||
TreeEnsembleAttributesV5<ThresholdType> attributes(info);
|
||||
return Init(80, 128, 50, attributes);
|
||||
}
|
||||
|
||||
template <typename IOType, typename ThresholdType>
|
||||
Status TreeEnsembleCommonV5<IOType, ThresholdType>::Init(
|
||||
int parallel_tree,
|
||||
int parallel_tree_N,
|
||||
int parallel_N,
|
||||
const TreeEnsembleAttributesV5<ThresholdType>& attributes) {
|
||||
TreeEnsembleAttributesV3<ThresholdType> attributes_v3;
|
||||
attributes.convert_to_v3(attributes_v3);
|
||||
|
||||
attributes_v3.base_values.clear();
|
||||
attributes_v3.base_values_as_tensor.clear();
|
||||
attributes_v3.nodes_hitrates.clear();
|
||||
attributes_v3.nodes_values.clear();
|
||||
attributes_v3.target_class_weights.clear();
|
||||
|
||||
auto status = TreeEnsembleCommon<IOType, ThresholdType, IOType>::Init(parallel_tree, parallel_tree_N, parallel_N, attributes_v3);
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace ml
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -5,63 +5,53 @@
|
|||
|
||||
#include "core/providers/cpu/ml/tree_ensemble_helper.h"
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "onnx/defs/tensor_proto_util.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
|
||||
using namespace ::onnxruntime::common;
|
||||
using namespace std;
|
||||
namespace onnxruntime {
|
||||
namespace ml {
|
||||
|
||||
Status GetNumberOfElementsAttrsOrDefault(const OpKernelInfo& info, const std::string& name,
|
||||
ONNX_NAMESPACE::TensorProto_DataType proto_type,
|
||||
size_t& n_elements, ONNX_NAMESPACE::TensorProto& proto) {
|
||||
auto status = info.GetAttr(name, &proto);
|
||||
if (!status.IsOK()) {
|
||||
// Attribute is missing, n_elements is set to 0.
|
||||
n_elements = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
auto n_dims = proto.dims_size();
|
||||
if (n_dims == 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attribute:'", name, "' is specified but is empty.");
|
||||
}
|
||||
ORT_ENFORCE(n_dims == 1, "Attribute '", name, "' must be a vector.");
|
||||
ORT_ENFORCE(proto.data_type() == proto_type,
|
||||
"Unexpected type (", proto.data_type(), "(for attribute '", name, "'.");
|
||||
|
||||
n_elements = onnxruntime::narrow<size_t>(proto.dims()[0]);
|
||||
ORT_ENFORCE(n_elements > 0, "Attribute '", name, "' has one dimension but is empty.");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename TH>
|
||||
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name,
|
||||
ONNX_NAMESPACE::TensorProto_DataType proto_type, std::vector<TH>& data) {
|
||||
if (proto_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE) {
|
||||
ORT_ENFORCE((std::is_same<double, TH>::value));
|
||||
} else if (proto_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) {
|
||||
ORT_ENFORCE((std::is_same<float, TH>::value));
|
||||
} else {
|
||||
ORT_NOT_IMPLEMENTED("GetVectorAttrsOrDefault not implemented for type ", proto_type);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status GetAnyVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<T>& data) {
|
||||
ONNX_NAMESPACE::TensorProto proto;
|
||||
size_t n_elements;
|
||||
data.clear();
|
||||
ORT_THROW_IF_ERROR(GetNumberOfElementsAttrsOrDefault(info, name, proto_type, n_elements, proto));
|
||||
if (n_elements == 0) {
|
||||
auto result = info.GetAttr(name, &proto);
|
||||
|
||||
SafeInt<int64_t> n_elements(1);
|
||||
for (auto dim : proto.dims()) {
|
||||
n_elements *= dim;
|
||||
}
|
||||
|
||||
if (proto.dims().empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
data = ONNX_NAMESPACE::ParseData<TH>(&proto);
|
||||
|
||||
const SafeInt<size_t> tensor_size(n_elements);
|
||||
data.clear();
|
||||
data.resize(tensor_size);
|
||||
|
||||
result = utils::UnpackTensor<T>(proto, std::filesystem::path(), data.data(), tensor_size);
|
||||
ORT_ENFORCE(result.IsOK(), "TreeEnsemble could not unpack tensor attribute ", name);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<double>& data) {
|
||||
return GetVectorAttrsOrDefault(info, name, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE, data);
|
||||
return GetAnyVectorAttrsOrDefault(info, name, data);
|
||||
}
|
||||
|
||||
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<float>& data) {
|
||||
return GetVectorAttrsOrDefault(info, name, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, data);
|
||||
return GetAnyVectorAttrsOrDefault(info, name, data);
|
||||
}
|
||||
|
||||
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<MLFloat16>& data) {
|
||||
return GetAnyVectorAttrsOrDefault(info, name, data);
|
||||
}
|
||||
|
||||
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<uint8_t>& data) {
|
||||
return GetAnyVectorAttrsOrDefault(info, name, data);
|
||||
}
|
||||
|
||||
} // namespace ml
|
||||
|
|
|
@ -13,6 +13,8 @@ namespace ml {
|
|||
|
||||
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<double>& data);
|
||||
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<float>& data);
|
||||
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<MLFloat16>& data);
|
||||
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<uint8_t>& data);
|
||||
|
||||
} // namespace ml
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -132,6 +132,7 @@ DEFAULT_OP_BLOCK_LIST = [
|
|||
"Scaler",
|
||||
"TreeEnsembleClassifier",
|
||||
"TreeEnsembleRegressor",
|
||||
"TreeEnsemble",
|
||||
"ZipMap",
|
||||
"NonMaxSuppression",
|
||||
"TopK",
|
||||
|
|
|
@ -0,0 +1,294 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
static ONNX_NAMESPACE::TensorProto make_tensor(std::vector<double> array, std::string name) {
|
||||
ONNX_NAMESPACE::TensorProto array_as_tensor;
|
||||
array_as_tensor.set_name(name);
|
||||
array_as_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE);
|
||||
array_as_tensor.add_dims(array.size());
|
||||
for (auto v : array) {
|
||||
array_as_tensor.add_double_data(v);
|
||||
}
|
||||
|
||||
return array_as_tensor;
|
||||
}
|
||||
|
||||
static ONNX_NAMESPACE::TensorProto make_tensor(std::vector<float> array, std::string name) {
|
||||
ONNX_NAMESPACE::TensorProto array_as_tensor;
|
||||
array_as_tensor.set_name(name);
|
||||
array_as_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
array_as_tensor.add_dims(array.size());
|
||||
for (auto v : array) {
|
||||
array_as_tensor.add_float_data(v);
|
||||
}
|
||||
|
||||
return array_as_tensor;
|
||||
}
|
||||
|
||||
static ONNX_NAMESPACE::TensorProto make_tensor(std::vector<uint8_t> array, std::string name) {
|
||||
ONNX_NAMESPACE::TensorProto array_as_tensor;
|
||||
array_as_tensor.set_name(name);
|
||||
array_as_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
|
||||
array_as_tensor.add_dims(array.size());
|
||||
for (const auto v : array) {
|
||||
array_as_tensor.add_int32_data(v);
|
||||
}
|
||||
|
||||
return array_as_tensor;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void _multiply_update_array(std::vector<T>& data, int n, T inc = 0) {
|
||||
std::vector<T> copy = data;
|
||||
data.resize(copy.size() * n);
|
||||
T cst = 0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
for (size_t j = 0; j < copy.size(); ++j) {
|
||||
data[j + i * copy.size()] = copy[j] + cst;
|
||||
}
|
||||
cst += inc;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void _multiply_update_childnode(std::vector<T>& childnodes, std::vector<T>& childleafs, std::vector<T>& otherchildleafs, int n) {
|
||||
int64_t leafs_cnt = 0;
|
||||
int64_t nodes_cnt = childnodes.size();
|
||||
for (auto& childleaf : childleafs) {
|
||||
if (childleaf) {
|
||||
leafs_cnt++;
|
||||
}
|
||||
}
|
||||
for (auto& childleaf : otherchildleafs) {
|
||||
if (childleaf) {
|
||||
leafs_cnt++;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<T> copy = childnodes;
|
||||
childnodes.resize(copy.size() * n);
|
||||
T leafs_cst = 0;
|
||||
T nodes_cst = 0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
for (size_t j = 0; j < copy.size(); ++j) {
|
||||
T curr_inc = childleafs[j] ? leafs_cst : nodes_cst;
|
||||
childnodes[j + i * copy.size()] = copy[j] + curr_inc;
|
||||
}
|
||||
|
||||
leafs_cst += leafs_cnt;
|
||||
nodes_cst += nodes_cnt;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void _multiply_arrays_values(std::vector<T>& data, int64_t val) {
|
||||
for (auto& curr : data) {
|
||||
curr *= val;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void GenTreeAndRunTest(const std::vector<T>& X, const std::vector<T>& Y, const int64_t& aggregate_function, int n_trees = 1) {
|
||||
OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain);
|
||||
int64_t n_targets = 2;
|
||||
|
||||
int64_t post_transform = 0;
|
||||
std::vector<int64_t> tree_roots = {0};
|
||||
std::vector<int64_t> nodes_featureids = {0, 0, 0};
|
||||
std::vector<uint8_t> nodes_modes = {0, 0, 0};
|
||||
std::vector<T> nodes_splits = {3.14f, 1.2f, 4.2f};
|
||||
std::vector<int64_t> nodes_truenodeids = {1, 0, 1};
|
||||
std::vector<int64_t> nodes_trueleafs = {0, 1, 1};
|
||||
std::vector<int64_t> nodes_falsenodeids = {2, 2, 3};
|
||||
std::vector<int64_t> nodes_falseleafs = {0, 1, 1};
|
||||
|
||||
std::vector<int64_t> leaf_targetids = {0, 1, 0, 1};
|
||||
std::vector<T> leaf_weights = {5.23f, 12.12f, -12.23f, 7.21f};
|
||||
|
||||
if (n_trees > 1) {
|
||||
// Multiplies the number of trees to test the parallelization by trees.
|
||||
_multiply_update_array(tree_roots, n_trees, (int64_t)nodes_truenodeids.size());
|
||||
_multiply_update_array(nodes_featureids, n_trees);
|
||||
_multiply_update_childnode(nodes_truenodeids, nodes_trueleafs, nodes_falseleafs, n_trees);
|
||||
_multiply_update_childnode(nodes_falsenodeids, nodes_falseleafs, nodes_trueleafs, n_trees);
|
||||
_multiply_update_array(nodes_trueleafs, n_trees);
|
||||
_multiply_update_array(nodes_falseleafs, n_trees);
|
||||
_multiply_update_array(leaf_targetids, n_trees);
|
||||
_multiply_update_array(nodes_modes, n_trees);
|
||||
_multiply_update_array(nodes_splits, n_trees);
|
||||
_multiply_update_array(leaf_weights, n_trees);
|
||||
}
|
||||
|
||||
auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes");
|
||||
auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits");
|
||||
auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight");
|
||||
|
||||
// add attributes
|
||||
test.AddAttribute("n_targets", n_targets);
|
||||
test.AddAttribute("aggregate_function", aggregate_function);
|
||||
test.AddAttribute("post_transform", post_transform);
|
||||
test.AddAttribute("tree_roots", tree_roots);
|
||||
test.AddAttribute("nodes_modes", nodes_modes_as_tensor);
|
||||
test.AddAttribute("nodes_featureids", nodes_featureids);
|
||||
test.AddAttribute("nodes_splits", nodes_splits_as_tensor);
|
||||
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
|
||||
test.AddAttribute("nodes_trueleafs", nodes_trueleafs);
|
||||
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
|
||||
test.AddAttribute("nodes_falseleafs", nodes_falseleafs);
|
||||
test.AddAttribute("leaf_targetids", leaf_targetids);
|
||||
test.AddAttribute("leaf_weights", leaf_weights_as_tensor);
|
||||
|
||||
// fill input data
|
||||
test.AddInput<T>("X", {3, 2}, X);
|
||||
test.AddOutput<T>("Y", {3, 2}, Y);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void GenTreeAndRunTestWithSetMembership(const std::vector<T>& X, const std::vector<T>& Y, const int64_t& aggregate_function, int n_trees = 1) {
|
||||
OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain);
|
||||
int64_t n_targets = 4;
|
||||
|
||||
int64_t post_transform = 0;
|
||||
std::vector<int64_t> tree_roots = {0};
|
||||
std::vector<int64_t> nodes_featureids = {0, 0, 0};
|
||||
std::vector<int64_t> nodes_truenodeids = {1, 0, 1};
|
||||
std::vector<int64_t> nodes_trueleafs = {0, 1, 1};
|
||||
std::vector<int64_t> nodes_falsenodeids = {2, 2, 3};
|
||||
std::vector<int64_t> nodes_falseleafs = {1, 0, 1};
|
||||
std::vector<int64_t> leaf_targetids = {0, 1, 2, 3};
|
||||
|
||||
std::vector<uint8_t> nodes_modes = {0, 6, 6};
|
||||
std::vector<T> nodes_splits = {11.f, 232344.f, NAN};
|
||||
std::vector<T> membership_values = {1.2f, 3.7f, 8.f, 9.f, NAN, 12.f, 7.f, NAN};
|
||||
std::vector<T> leaf_weights = {1.f, 10.f, 1000.f, 100.f};
|
||||
|
||||
if (n_trees > 1) {
|
||||
// Multiplies the number of trees to test the parallelization by trees.
|
||||
_multiply_update_array(tree_roots, n_trees, (int64_t)nodes_truenodeids.size());
|
||||
_multiply_update_array(nodes_featureids, n_trees);
|
||||
_multiply_update_childnode(nodes_truenodeids, nodes_trueleafs, nodes_falseleafs, n_trees);
|
||||
_multiply_update_childnode(nodes_falsenodeids, nodes_falseleafs, nodes_trueleafs, n_trees);
|
||||
_multiply_update_array(nodes_trueleafs, n_trees);
|
||||
_multiply_update_array(nodes_falseleafs, n_trees);
|
||||
_multiply_update_array(leaf_targetids, n_trees);
|
||||
_multiply_update_array(nodes_modes, n_trees);
|
||||
_multiply_update_array(nodes_splits, n_trees);
|
||||
_multiply_update_array(membership_values, n_trees);
|
||||
_multiply_update_array(leaf_weights, n_trees);
|
||||
}
|
||||
|
||||
auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes");
|
||||
auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits");
|
||||
auto membership_values_as_tensor = make_tensor(membership_values, "membership_values");
|
||||
auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight");
|
||||
|
||||
// add attributes
|
||||
test.AddAttribute("n_targets", n_targets);
|
||||
test.AddAttribute("aggregate_function", aggregate_function);
|
||||
test.AddAttribute("post_transform", post_transform);
|
||||
test.AddAttribute("tree_roots", tree_roots);
|
||||
test.AddAttribute("nodes_modes", nodes_modes_as_tensor);
|
||||
test.AddAttribute("nodes_featureids", nodes_featureids);
|
||||
test.AddAttribute("nodes_splits", nodes_splits_as_tensor);
|
||||
test.AddAttribute("membership_values", membership_values_as_tensor);
|
||||
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
|
||||
test.AddAttribute("nodes_trueleafs", nodes_trueleafs);
|
||||
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
|
||||
test.AddAttribute("nodes_falseleafs", nodes_falseleafs);
|
||||
test.AddAttribute("leaf_targetids", leaf_targetids);
|
||||
test.AddAttribute("leaf_weights", leaf_weights_as_tensor);
|
||||
|
||||
// fill input data
|
||||
test.AddInput<T>("X", {6, 1}, X);
|
||||
test.AddOutput<T>("Y", {6, 4}, Y);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MLOpTest, TreeEnsembleFloat) {
|
||||
std::vector<float> X = {1.2f, 3.4f, -0.12f, 1.66f, 4.14f, 1.77f};
|
||||
std::vector<float> Y = {5.23f, 0.f, 5.23f, 0.f, 0.f, 12.12f};
|
||||
GenTreeAndRunTest<float>(X, Y, 1, 1);
|
||||
|
||||
Y = {15.69f, 0.f, 15.69f, 0.f, 0.f, 36.36f};
|
||||
GenTreeAndRunTest<float>(X, Y, 1, 3);
|
||||
}
|
||||
|
||||
TEST(MLOpTest, TreeEnsembleDouble) {
|
||||
std::vector<double> X = {1.2f, 3.4f, -0.12f, 1.66f, 4.14f, 1.77f};
|
||||
std::vector<double> Y = {5.23f, 0.f, 5.23f, 0.f, 0.f, 12.12f};
|
||||
GenTreeAndRunTest<double>(X, Y, 1, 1);
|
||||
|
||||
_multiply_arrays_values(Y, 3);
|
||||
GenTreeAndRunTest<double>(X, Y, 1, 3);
|
||||
}
|
||||
|
||||
TEST(MLOpTest, TreeEnsembleSetMembership) {
|
||||
std::vector<double> X = {1.2f, 3.4f, -0.12f, NAN, 12.0f, 7.0f};
|
||||
std::vector<double> Y = {
|
||||
1.f, 0.f, 0.f, 0.f,
|
||||
0.f, 0.f, 0.f, 100.f,
|
||||
0.f, 0.f, 0.f, 100.f,
|
||||
0.f, 0.f, 1000.f, 0.f,
|
||||
0.f, 0.f, 1000.f, 0.f,
|
||||
0.f, 10.f, 0.f, 0.f};
|
||||
GenTreeAndRunTestWithSetMembership<double>(X, Y, 1, 1);
|
||||
|
||||
_multiply_arrays_values(Y, 5);
|
||||
GenTreeAndRunTestWithSetMembership<double>(X, Y, 1, 5);
|
||||
}
|
||||
|
||||
TEST(MLOpTest, TreeEnsembleLeafOnly) {
|
||||
OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain);
|
||||
int64_t n_targets = 1;
|
||||
|
||||
int64_t aggregate_function = 1;
|
||||
int64_t post_transform = 0;
|
||||
std::vector<int64_t> tree_roots = {0};
|
||||
std::vector<uint8_t> nodes_modes = {0};
|
||||
std::vector<int64_t> nodes_featureids = {0};
|
||||
std::vector<double> nodes_splits = {0.f};
|
||||
std::vector<int64_t> nodes_truenodeids = {0};
|
||||
std::vector<int64_t> nodes_trueleafs = {1};
|
||||
std::vector<int64_t> nodes_falsenodeids = {0};
|
||||
std::vector<int64_t> nodes_falseleafs = {1};
|
||||
|
||||
std::vector<int64_t> leaf_targetids = {0};
|
||||
std::vector<double> leaf_weights = {6.23f};
|
||||
|
||||
auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes");
|
||||
auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits");
|
||||
auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight");
|
||||
|
||||
// add attributes
|
||||
test.AddAttribute("n_targets", n_targets);
|
||||
test.AddAttribute("aggregate_function", aggregate_function);
|
||||
test.AddAttribute("post_transform", post_transform);
|
||||
test.AddAttribute("tree_roots", tree_roots);
|
||||
test.AddAttribute("nodes_modes", nodes_modes_as_tensor);
|
||||
test.AddAttribute("nodes_featureids", nodes_featureids);
|
||||
test.AddAttribute("nodes_splits", nodes_splits_as_tensor);
|
||||
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
|
||||
test.AddAttribute("nodes_trueleafs", nodes_trueleafs);
|
||||
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
|
||||
test.AddAttribute("nodes_falseleafs", nodes_falseleafs);
|
||||
test.AddAttribute("leaf_targetids", leaf_targetids);
|
||||
test.AddAttribute("leaf_weights", leaf_weights_as_tensor);
|
||||
|
||||
// fill input data
|
||||
std::vector<double> X = {1.f, 4.f};
|
||||
std::vector<double> Y = {6.23f, 6.23f};
|
||||
|
||||
test.AddInput<double>("X", {2, 1}, X);
|
||||
test.AddOutput<double>("Y", {2, 1}, Y);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
|
@ -679,6 +679,90 @@ TEST(MLOpTest, TreeRegressorSingleTargetSum_as_tensor_precision) {
|
|||
GenTreeAndRunTest1_as_tensor_precision(3);
|
||||
}
|
||||
|
||||
TEST(MLOpTest, TreeRegressorCategoricals) {
|
||||
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);
|
||||
|
||||
// tree
|
||||
int64_t n_targets = 1;
|
||||
std::vector<int64_t> nodes_featureids = {0, 0, 0, 0, 1, 0, 0};
|
||||
std::vector<std::string> nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF"};
|
||||
std::vector<float> nodes_values = {1, 3, 4, 0, 5.5, 0, 0};
|
||||
|
||||
std::vector<int64_t> nodes_treeids = {0, 0, 0, 0, 0, 0, 0};
|
||||
std::vector<int64_t> nodes_nodeids = {0, 1, 2, 3, 4, 5, 6};
|
||||
std::vector<int64_t> nodes_falsenodeids = {1, 2, 3, 0, 5, 0, 0};
|
||||
std::vector<int64_t> nodes_truenodeids = {4, 4, 4, 0, 6, 0, 0};
|
||||
|
||||
std::string post_transform = "NONE";
|
||||
std::vector<int64_t> target_ids = {0, 0, 0};
|
||||
std::vector<int64_t> target_nodeids = {3, 5, 6};
|
||||
std::vector<int64_t> target_treeids = {0, 0, 0};
|
||||
std::vector<float> target_weights = {-4.699999809265137, 17.700000762939453, 11.100000381469727};
|
||||
|
||||
// add attributes
|
||||
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
|
||||
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
|
||||
test.AddAttribute("nodes_treeids", nodes_treeids);
|
||||
test.AddAttribute("nodes_nodeids", nodes_nodeids);
|
||||
test.AddAttribute("nodes_featureids", nodes_featureids);
|
||||
test.AddAttribute("nodes_values", nodes_values);
|
||||
test.AddAttribute("nodes_modes", nodes_modes);
|
||||
test.AddAttribute("target_treeids", target_treeids);
|
||||
test.AddAttribute("target_nodeids", target_nodeids);
|
||||
test.AddAttribute("target_ids", target_ids);
|
||||
test.AddAttribute("target_weights", target_weights);
|
||||
test.AddAttribute("n_targets", n_targets);
|
||||
|
||||
// fill input data
|
||||
std::vector<float> X = {3.0f, 6.6f, 1.0f, 5.0f, 5.0f, 5.5f};
|
||||
std::vector<float> Y = {17.700000762939453, 11.100000381469727, -4.699999809265137};
|
||||
test.AddInput<float>("X", {3, 2}, X);
|
||||
test.AddOutput<float>("Y", {3, 1}, Y);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MLOpTest, TreeRegressorCategoricalsFolding) {
|
||||
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);
|
||||
|
||||
// tree
|
||||
int64_t n_targets = 1;
|
||||
std::vector<int64_t> nodes_featureids = {0, 0, 1, 1, 0, 0, 0};
|
||||
std::vector<std::string> nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "LEAF", "LEAF"};
|
||||
std::vector<float> nodes_values = {1, 3, 2, 3, 0, 0, 0};
|
||||
|
||||
std::vector<int64_t> nodes_treeids = {0, 0, 0, 0, 0, 0, 0};
|
||||
std::vector<int64_t> nodes_nodeids = {0, 1, 2, 3, 4, 5, 6};
|
||||
std::vector<int64_t> nodes_falsenodeids = {1, 2, 3, 4, 0, 0, 0};
|
||||
std::vector<int64_t> nodes_truenodeids = {5, 5, 6, 6, 0, 0, 0};
|
||||
|
||||
std::string post_transform = "NONE";
|
||||
std::vector<int64_t> target_ids = {0, 0, 0};
|
||||
std::vector<int64_t> target_nodeids = {4, 5, 6};
|
||||
std::vector<int64_t> target_treeids = {0, 0, 0};
|
||||
std::vector<float> target_weights = {17.700000762939453, 11.100000381469727, -4.699999809265137};
|
||||
|
||||
// add attributes
|
||||
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
|
||||
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
|
||||
test.AddAttribute("nodes_treeids", nodes_treeids);
|
||||
test.AddAttribute("nodes_nodeids", nodes_nodeids);
|
||||
test.AddAttribute("nodes_featureids", nodes_featureids);
|
||||
test.AddAttribute("nodes_values", nodes_values);
|
||||
test.AddAttribute("nodes_modes", nodes_modes);
|
||||
test.AddAttribute("target_treeids", target_treeids);
|
||||
test.AddAttribute("target_nodeids", target_nodeids);
|
||||
test.AddAttribute("target_ids", target_ids);
|
||||
test.AddAttribute("target_weights", target_weights);
|
||||
test.AddAttribute("n_targets", n_targets);
|
||||
|
||||
// fill input data
|
||||
std::vector<float> X = {1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f};
|
||||
std::vector<float> Y = {11.100000381469727, 11.100000381469727, -4.699999809265137, 17.700000762939453};
|
||||
test.AddInput<float>("X", {4, 2}, X);
|
||||
test.AddOutput<float>("Y", {4, 1}, Y);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(MLOpTest, TreeRegressorTrueNodeBeforeNode) {
|
||||
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче