Implementation of TreeEnsemble ai.onnx.ml==5 (#22333)

### Description
Merges PRs #21851 and #21222.

Implements the TreeEnsemble operator from ai.onnx.ml opset 5 (CPU execution provider).
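For readers unfamiliar with the opset-5 layout, here is a minimal, illustrative sketch (not part of this change) of how the new attributes encode a forest: interior nodes live in the `nodes_*` arrays, leaves live in the separate `leaf_*` arrays, and `nodes_trueleafs`/`nodes_falseleafs` indicate whether a child index points into the leaf arrays or back into the node arrays. Only `BRANCH_LEQ` splits and `SUM` aggregation are shown.

```cpp
#include <cstdint>
#include <vector>

// Illustrative only: a stripped-down view of the ai.onnx.ml opset-5 TreeEnsemble
// attribute layout. The real kernel lives in tree_ensemble_common.h.
struct ToyTreeEnsembleV5 {
  std::vector<int64_t> tree_roots;                         // index of each tree's root in the nodes_* arrays
  std::vector<int64_t> nodes_featureids;                   // feature tested at each interior node
  std::vector<float> nodes_splits;                         // threshold (BRANCH_LEQ only in this sketch)
  std::vector<int64_t> nodes_truenodeids, nodes_falsenodeids;
  std::vector<int64_t> nodes_trueleafs, nodes_falseleafs;  // 1 => the child id indexes the leaf_* arrays
  std::vector<int64_t> leaf_targetids;                     // target written by each leaf
  std::vector<float> leaf_weights;                         // value added by each leaf
  int64_t n_targets = 1;

  // SUM aggregation over all trees for a single row x.
  // Assumes every tree has at least one interior node (the real operator also allows leaf-only roots).
  std::vector<float> Predict(const std::vector<float>& x) const {
    std::vector<float> out(static_cast<size_t>(n_targets), 0.f);
    for (int64_t root : tree_roots) {
      size_t node = static_cast<size_t>(root);
      bool is_leaf = false;
      while (!is_leaf) {
        const bool go_true = x[static_cast<size_t>(nodes_featureids[node])] <= nodes_splits[node];
        is_leaf = (go_true ? nodes_trueleafs[node] : nodes_falseleafs[node]) != 0;
        node = static_cast<size_t>(go_true ? nodes_truenodeids[node] : nodes_falsenodeids[node]);
      }
      out[static_cast<size_t>(leaf_targetids[node])] += leaf_weights[node];
    }
    return out;
  }
};
```

The actual kernel additionally handles the other branch modes, set membership, missing-value tracking, post transforms, and parallelism.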

---------

Co-authored-by: Bilyana Indzheva <bilyana2002@gmail.com>
Co-authored-by: Bilyana Indzheva <36890669+bili2002@users.noreply.github.com>
Co-authored-by: Christian Bourjau <cbourjau@users.noreply.github.com>
This commit is contained in:
Xavier Dupré 2024-11-22 19:48:23 +01:00 committed by GitHub
Parent c97dd6e3c1
Commit a2ba3cb547
No key found matching this signature
GPG key ID: B5690EEEBB952194
13 changed files: 1157 additions and 351 deletions

View file

@@ -453,6 +453,7 @@ Do not modify directly.*
|SVMClassifier|*in* X:**T1**<br> *out* Y:**T2**<br> *out* Z:**tensor(float)**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int64), tensor(string)|
|SVMRegressor|*in* X:**T**<br> *out* Y:**tensor(float)**|1+|**T** = tensor(float)|
|Scaler|*in* X:**T**<br> *out* Y:**tensor(float)**|1+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|TreeEnsemble|*in* X:**T**<br> *out* Y:**T**|5+|**T** = tensor(double), tensor(float)|
|TreeEnsembleClassifier|*in* X:**T1**<br> *out* Y:**T2**<br> *out* Z:**tensor(float)**|3+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int64), tensor(string)|
|||[1, 2]|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int64), tensor(string)|
|TreeEnsembleRegressor|*in* X:**T**<br> *out* Y:**tensor(float)**|3+|**T** = tensor(double), tensor(float)|

View file

@@ -2925,6 +2925,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, int32_t, TreeEnsembleClassifier);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, float, TreeEnsembleRegressor);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double, TreeEnsembleRegressor);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, float, TreeEnsemble);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, double, TreeEnsemble);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, float_string, LabelEncoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, string_float, LabelEncoder);
@@ -3043,6 +3045,10 @@ Status RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) {
TreeEnsembleRegressor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double,
TreeEnsembleRegressor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, float,
TreeEnsemble)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, double,
TreeEnsemble)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, float_string,
LabelEncoder)>,

View file

@@ -20,44 +20,48 @@ enum class OUTPUT_MODE {
ALL_SCORES
};
enum NODE_MODE : uint8_t {
LEAF = 1,
BRANCH_LEQ = 2,
BRANCH_LT = 4,
BRANCH_GTE = 6,
BRANCH_GT = 8,
BRANCH_EQ = 10,
BRANCH_NEQ = 12
enum NODE_MODE_ONNX : uint8_t {
BRANCH_LEQ = 0,
BRANCH_LT = 1,
BRANCH_GTE = 2,
BRANCH_GT = 3,
BRANCH_EQ = 4,
BRANCH_NEQ = 5,
BRANCH_MEMBER = 6,
LEAF = 7,
};
static inline NODE_MODE MakeTreeNodeMode(const std::string& input) {
static inline NODE_MODE_ONNX MakeTreeNodeMode(const std::string& input) {
if (input == "BRANCH_LEQ") {
return NODE_MODE::BRANCH_LEQ;
return NODE_MODE_ONNX::BRANCH_LEQ;
}
if (input == "LEAF") {
return NODE_MODE::LEAF;
return NODE_MODE_ONNX::LEAF;
}
if (input == "BRANCH_LT") {
return NODE_MODE::BRANCH_LT;
return NODE_MODE_ONNX::BRANCH_LT;
}
if (input == "BRANCH_GTE") {
return NODE_MODE::BRANCH_GTE;
return NODE_MODE_ONNX::BRANCH_GTE;
}
if (input == "BRANCH_GT") {
return NODE_MODE::BRANCH_GT;
return NODE_MODE_ONNX::BRANCH_GT;
}
if (input == "BRANCH_EQ") {
return NODE_MODE::BRANCH_EQ;
return NODE_MODE_ONNX::BRANCH_EQ;
}
return NODE_MODE::BRANCH_NEQ;
if (input == "BRANCH_MEMBER") {
return NODE_MODE_ONNX::BRANCH_MEMBER;
}
return NODE_MODE_ONNX::BRANCH_NEQ;
}
enum class POST_EVAL_TRANSFORM {
NONE,
LOGISTIC,
SOFTMAX,
SOFTMAX_ZERO,
PROBIT
enum class POST_EVAL_TRANSFORM : int64_t {
NONE = 0,
LOGISTIC = 1,
SOFTMAX = 2,
SOFTMAX_ZERO = 3,
PROBIT = 4
};
static inline POST_EVAL_TRANSFORM MakeTransform(const std::string& input) {
@@ -76,11 +80,11 @@ static inline POST_EVAL_TRANSFORM MakeTransform(const std::string& input) {
return POST_EVAL_TRANSFORM::PROBIT;
}
enum class AGGREGATE_FUNCTION {
AVERAGE,
SUM,
MIN,
MAX
enum class AGGREGATE_FUNCTION : int64_t {
AVERAGE = 0,
SUM = 1,
MIN = 2,
MAX = 3
};
static inline AGGREGATE_FUNCTION MakeAggregateFunction(const std::string& input) {

View file

@@ -0,0 +1,59 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cpu/ml/tree_ensemble.h"
#include "core/providers/cpu/ml/tree_ensemble_helper.h"
#include "core/common/inlined_containers_fwd.h"
namespace onnxruntime {
namespace ml {
ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
TreeEnsemble,
5,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).MayInplace(0, 0),
TreeEnsemble<float>);
ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
TreeEnsemble,
5,
double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()).MayInplace(0, 0),
TreeEnsemble<double>);
template <typename T>
TreeEnsemble<T>::TreeEnsemble(const OpKernelInfo& info) : OpKernel(info) {
if constexpr (std::is_same<T, double>::value) {
p_tree_ensemble_ = std::make_unique<detail::TreeEnsembleCommonV5<T, double>>();
} else {
p_tree_ensemble_ = std::make_unique<detail::TreeEnsembleCommonV5<T, float>>();
}
ORT_THROW_IF_ERROR(p_tree_ensemble_->Init(info));
}
template <typename T>
Status TreeEnsemble<T>::GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const {
InlinedVector<std::string> names{
"leaf_targetids", "leaf_weights", "membership_values", "nodes_falseleafs",
"nodes_falsenodeids", "nodes_featureids", "nodes_hitrates", "nodes_missing_value_tracks_true",
"nodes_modes", "nodes_splits", "nodes_trueleafs", "nodes_truenodeids"};
removable_attributes.swap(names);
return Status::OK();
}
template <typename T>
common::Status TreeEnsemble<T>::Compute(OpKernelContext* context) const {
const auto* X = context->Input<Tensor>(0);
if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
if (X->Shape().NumDimensions() == 0) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"Input shape needs to be at least a single dimension.");
}
int64_t N = X->Shape().NumDimensions() == 1 ? 1 : X->Shape()[0];
Tensor* Y = context->Output(0, {N, p_tree_ensemble_->get_target_or_class_count()});
return p_tree_ensemble_->compute(context, X, Y, NULL);
}
} // namespace ml
} // namespace onnxruntime
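As a usage sketch (assuming the ONNX Runtime C++ API and a hypothetical `tree_ensemble_v5.onnx` model whose only node is the opset-5 TreeEnsemble, with graph input `X` and output `Y`), the kernel above produces a `{N, n_targets}` output, treating a rank-1 input as a single row:

```cpp
#include <array>
#include <vector>
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "tree_ensemble_v5");
  Ort::SessionOptions options;
  // Hypothetical model file: any graph whose single node is ai.onnx.ml TreeEnsemble (opset 5).
  Ort::Session session(env, "tree_ensemble_v5.onnx", options);

  std::vector<float> x = {0.2f, 1.5f, -3.0f};  // one row, three features
  std::array<int64_t, 2> x_shape = {1, 3};
  Ort::MemoryInfo mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value input = Ort::Value::CreateTensor<float>(mem, x.data(), x.size(),
                                                     x_shape.data(), x_shape.size());

  const char* input_names[] = {"X"};   // assumed graph input/output names
  const char* output_names[] = {"Y"};
  auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input, 1, output_names, 1);

  // Output shape is {N, n_targets}; here N == 1.
  auto shape = outputs[0].GetTensorTypeAndShapeInfo().GetShape();
  return shape.size() == 2 && shape[0] == 1 ? 0 : 1;
}
```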

View file

@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "tree_ensemble_common.h"
namespace onnxruntime {
namespace ml {
template <typename T>
class TreeEnsemble final : public OpKernel {
typedef T InputType; // input type
typedef float OutputType; // output type
public:
explicit TreeEnsemble(const OpKernelInfo& info);
common::Status Compute(OpKernelContext* context) const override;
Status GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const override;
private:
// Pointer to one instance of
// detail::TreeEnsembleCommonV5<T, ThresholdType>
// where ThresholdType is defined after accessing the attributes.
std::unique_ptr<detail::TreeEnsembleCommonAttributes> p_tree_ensemble_;
};
} // namespace ml
} // namespace onnxruntime
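The indirection through `TreeEnsembleCommonAttributes` exists because the threshold type is only known once the attributes have been read; a minimal sketch of the same dispatch pattern (names below are illustrative, not the actual classes):

```cpp
#include <cstdint>
#include <memory>
#include <type_traits>

// Sketch: the kernel stores only a base-class pointer, and the threshold type of
// the concrete implementation is decided when the kernel is constructed.
struct CommonBase {
  virtual ~CommonBase() = default;
  virtual int64_t get_target_or_class_count() const = 0;
};

template <typename InputT, typename ThresholdT>
struct CommonV5 : CommonBase {
  int64_t get_target_or_class_count() const override { return 1; }  // placeholder
};

template <typename T>
std::unique_ptr<CommonBase> MakeCommon() {
  if constexpr (std::is_same_v<T, double>) {
    return std::make_unique<CommonV5<T, double>>();  // double inputs keep double thresholds
  } else {
    return std::make_unique<CommonV5<T, float>>();   // everything else uses float thresholds
  }
}
```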

View file

@@ -78,6 +78,40 @@ union PtrOrWeight {
} weight_data;
};
enum NODE_MODE_ORT : uint8_t {
LEAF = 1,
BRANCH_LEQ = 2,
BRANCH_LT = 4,
BRANCH_GTE = 6,
BRANCH_GT = 8,
BRANCH_EQ = 10,
BRANCH_NEQ = 12,
BRANCH_MEMBER = 14,
};
inline NODE_MODE_ORT Convert_NODE_MODE_ONNX_to_ORT(NODE_MODE_ONNX node_mode) {
switch (node_mode) {
case NODE_MODE_ONNX::LEAF:
return NODE_MODE_ORT::LEAF;
case NODE_MODE_ONNX::BRANCH_LEQ:
return NODE_MODE_ORT::BRANCH_LEQ;
case NODE_MODE_ONNX::BRANCH_LT:
return NODE_MODE_ORT::BRANCH_LT;
case NODE_MODE_ONNX::BRANCH_GTE:
return NODE_MODE_ORT::BRANCH_GTE;
case NODE_MODE_ONNX::BRANCH_GT:
return NODE_MODE_ORT::BRANCH_GT;
case NODE_MODE_ONNX::BRANCH_EQ:
return NODE_MODE_ORT::BRANCH_EQ;
case NODE_MODE_ONNX::BRANCH_NEQ:
return NODE_MODE_ORT::BRANCH_NEQ;
case NODE_MODE_ONNX::BRANCH_MEMBER:
return NODE_MODE_ORT::BRANCH_MEMBER;
default:
ORT_THROW("Unexpected value for node_mode");
};
}
template <typename T>
struct TreeNodeElement {
int feature_id;
@@ -98,10 +132,10 @@ struct TreeNodeElement {
// weight in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, the weight is also
// stored in `value_or_unique_weight`.
PtrOrWeight<T> truenode_or_weight;
uint8_t flags;
NODE_MODE_ORT flags;
inline NODE_MODE mode() const { return NODE_MODE(flags & 0xF); }
inline bool is_not_leaf() const { return !(flags & NODE_MODE::LEAF); }
inline NODE_MODE_ORT mode() const { return NODE_MODE_ORT(flags & 0xF); }
inline bool is_not_leaf() const { return !(flags & NODE_MODE_ORT::LEAF); }
inline bool is_missing_track_true() const { return flags & MissingTrack::kTrue; }
};
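The `flags` field packs the node mode into the low nibble, relying on `LEAF` being the only odd value, and leaves the upper bits for flags such as the missing-value marker. A standalone sketch of that packing (the `kMissingTrue` value below is an assumption for illustration; the real constant is `MissingTrack::kTrue`):

```cpp
#include <cassert>
#include <cstdint>

// Sketch of the flag packing used by TreeNodeElement: the node mode occupies the
// low 4 bits and extra flags (here an assumed missing-value bit) sit above them.
enum NODE_MODE_ORT : uint8_t {
  LEAF = 1,
  BRANCH_LEQ = 2,
  BRANCH_LT = 4,
  BRANCH_GTE = 6,
  BRANCH_GT = 8,
  BRANCH_EQ = 10,
  BRANCH_NEQ = 12,
  BRANCH_MEMBER = 14,
};

constexpr uint8_t kMissingTrue = 0x10;  // illustrative value standing in for MissingTrack::kTrue

struct Flags {
  uint8_t bits;
  NODE_MODE_ORT mode() const { return NODE_MODE_ORT(bits & 0xF); }
  bool is_not_leaf() const { return !(bits & NODE_MODE_ORT::LEAF); }
  bool is_missing_track_true() const { return (bits & kMissingTrue) != 0; }
};

int main() {
  Flags f{static_cast<uint8_t>(NODE_MODE_ORT::BRANCH_LEQ | kMissingTrue)};
  assert(f.mode() == NODE_MODE_ORT::BRANCH_LEQ);
  assert(f.is_not_leaf());            // the LEAF bit (0x1) is never set for branch modes
  assert(f.is_missing_track_true());
  return 0;
}
```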

View file

@@ -0,0 +1,321 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/inlined_containers.h"
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "ml_common.h"
#include "tree_ensemble_helper.h"
#include <vector>
namespace onnxruntime {
namespace ml {
namespace detail {
inline bool _isnan_(float x) { return std::isnan(x); }
inline bool _isnan_(double x) { return std::isnan(x); }
inline bool _isnan_(int64_t) { return false; }
inline bool _isnan_(int32_t) { return false; }
template <typename ThresholdType>
struct TreeEnsembleAttributesV3 {
TreeEnsembleAttributesV3() {}
TreeEnsembleAttributesV3(const OpKernelInfo& info, bool classifier) {
#if !defined(ORT_MINIMAL_BUILD)
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor));
if (classifier) {
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "class_weights_as_tensor", target_class_weights_as_tensor));
} else {
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "target_weights_as_tensor", target_class_weights_as_tensor));
}
#endif
aggregate_function = info.GetAttrOrDefault<std::string>("aggregate_function", "SUM");
base_values = info.GetAttrsOrDefault<float>("base_values");
nodes_falsenodeids = info.GetAttrsOrDefault<int64_t>("nodes_falsenodeids");
nodes_featureids = info.GetAttrsOrDefault<int64_t>("nodes_featureids");
nodes_missing_value_tracks_true = info.GetAttrsOrDefault<int64_t>("nodes_missing_value_tracks_true");
std::vector<std::string> nodes_modes_string = info.GetAttrsOrDefault<std::string>("nodes_modes");
nodes_modes.reserve(nodes_modes_string.size());
for (auto s : nodes_modes_string) {
nodes_modes.emplace_back(MakeTreeNodeMode(s));
}
nodes_nodeids = info.GetAttrsOrDefault<int64_t>("nodes_nodeids");
nodes_treeids = info.GetAttrsOrDefault<int64_t>("nodes_treeids");
nodes_truenodeids = info.GetAttrsOrDefault<int64_t>("nodes_truenodeids");
nodes_values = info.GetAttrsOrDefault<float>("nodes_values");
post_transform = info.GetAttrOrDefault<std::string>("post_transform", "NONE");
if (classifier) {
target_class_ids = info.GetAttrsOrDefault<int64_t>("class_ids");
target_class_nodeids = info.GetAttrsOrDefault<int64_t>("class_nodeids");
target_class_treeids = info.GetAttrsOrDefault<int64_t>("class_treeids");
target_class_weights = info.GetAttrsOrDefault<float>("class_weights");
classlabels_strings = info.GetAttrsOrDefault<std::string>("classlabels_strings");
classlabels_int64s = info.GetAttrsOrDefault<int64_t>("classlabels_int64s");
n_targets_or_classes = classlabels_strings.empty() ? classlabels_int64s.size()
: classlabels_strings.size();
} else {
n_targets_or_classes = info.GetAttrOrDefault<int64_t>("n_targets", 0);
target_class_ids = info.GetAttrsOrDefault<int64_t>("target_ids");
target_class_nodeids = info.GetAttrsOrDefault<int64_t>("target_nodeids");
target_class_treeids = info.GetAttrsOrDefault<int64_t>("target_treeids");
target_class_weights = info.GetAttrsOrDefault<float>("target_weights");
ORT_ENFORCE(n_targets_or_classes > 0);
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_featureids.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_modes_string.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_nodeids.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_treeids.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_truenodeids.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_values.size() ||
nodes_falsenodeids.size() == nodes_values_as_tensor.size());
ORT_ENFORCE(target_class_ids.size() == target_class_nodeids.size());
ORT_ENFORCE(target_class_ids.size() == target_class_treeids.size());
ORT_ENFORCE(target_class_weights.empty() || target_class_ids.size() == target_class_weights.size());
ORT_ENFORCE(base_values.empty() || base_values_as_tensor.empty());
ORT_ENFORCE(nodes_hitrates.empty() || nodes_hitrates_as_tensor.empty());
ORT_ENFORCE(nodes_values.empty() || nodes_values_as_tensor.empty());
ORT_ENFORCE(target_class_weights.empty() || target_class_weights_as_tensor.empty());
ORT_ENFORCE(nodes_modes_string.size() < std::numeric_limits<uint32_t>::max());
}
}
std::string aggregate_function;
std::vector<float> base_values;
std::vector<ThresholdType> base_values_as_tensor;
int64_t n_targets_or_classes;
std::vector<int64_t> nodes_falsenodeids;
std::vector<int64_t> nodes_featureids;
std::vector<float> nodes_hitrates;
std::vector<ThresholdType> nodes_hitrates_as_tensor;
std::vector<int64_t> nodes_missing_value_tracks_true;
std::vector<NODE_MODE_ONNX> nodes_modes;
std::vector<int64_t> nodes_nodeids;
std::vector<int64_t> nodes_treeids;
std::vector<int64_t> nodes_truenodeids;
std::vector<float> nodes_values;
std::vector<ThresholdType> nodes_values_as_tensor;
std::string post_transform;
std::vector<int64_t> target_class_ids;
std::vector<int64_t> target_class_nodeids;
std::vector<int64_t> target_class_treeids;
std::vector<float> target_class_weights;
std::vector<ThresholdType> target_class_weights_as_tensor;
std::vector<std::string> classlabels_strings;
std::vector<int64_t> classlabels_int64s;
std::vector<int64_t> class_labels;
};
template <typename ThresholdType>
struct TreeEnsembleAttributesV5 {
TreeEnsembleAttributesV5() {}
TreeEnsembleAttributesV5(const OpKernelInfo& info) {
#if !defined(ORT_MINIMAL_BUILD)
std::vector<uint8_t> nodes_modes_i;
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "leaf_weights", leaf_weights));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "membership_values", membership_values));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates", nodes_hitrates));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_modes", nodes_modes_i));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_splits", nodes_splits));
nodes_modes.reserve(nodes_modes_i.size());
for (auto i : nodes_modes_i) {
nodes_modes.push_back(static_cast<NODE_MODE_ONNX>(i));
}
#else
// GetVectorAttrsOrDefault is not part of the minimal build.
// As a result, TreeEnsemble v5 is not available in a minimal build.
ORT_THROW("TreeEnsemble(ai.onnx.ml==5) is not supported with the minimal build.");
#endif
aggregate_function = info.GetAttrOrDefault<int64_t>("aggregate_function", 1);
leaf_targetids = info.GetAttrsOrDefault<int64_t>("leaf_targetids");
n_targets = info.GetAttrOrDefault<int64_t>("n_targets", 0);
nodes_falseleafs = info.GetAttrsOrDefault<int64_t>("nodes_falseleafs");
nodes_falsenodeids = info.GetAttrsOrDefault<int64_t>("nodes_falsenodeids");
nodes_featureids = info.GetAttrsOrDefault<int64_t>("nodes_featureids");
nodes_missing_value_tracks_true = info.GetAttrsOrDefault<int64_t>("nodes_missing_value_tracks_true");
nodes_trueleafs = info.GetAttrsOrDefault<int64_t>("nodes_trueleafs");
nodes_truenodeids = info.GetAttrsOrDefault<int64_t>("nodes_truenodeids");
post_transform = info.GetAttrOrDefault<int64_t>("post_transform", 0);
tree_roots = info.GetAttrsOrDefault<int64_t>("tree_roots");
}
void convert_to_v3(TreeEnsembleAttributesV3<ThresholdType>& output) const {
// Perform all the transformations needed to obtain the old (V3) format.
output.n_targets_or_classes = n_targets;
output.aggregate_function = aggregateFunctionToString();
output.post_transform = postTransformToString();
std::vector<std::vector<ThresholdType>> membership_values_by_id;
getMembershipValuesById(membership_values_by_id);
transformInputAllTrees(output, membership_values_by_id);
}
int64_t aggregate_function;
std::vector<int64_t> leaf_targetids;
std::vector<ThresholdType> leaf_weights;
std::vector<ThresholdType> membership_values;
int64_t n_targets;
std::vector<int64_t> nodes_falseleafs;
std::vector<int64_t> nodes_falsenodeids;
std::vector<int64_t> nodes_featureids;
std::vector<ThresholdType> nodes_hitrates;
std::vector<int64_t> nodes_missing_value_tracks_true;
std::vector<NODE_MODE_ONNX> nodes_modes;
std::vector<ThresholdType> nodes_splits;
std::vector<int64_t> nodes_trueleafs;
std::vector<int64_t> nodes_truenodeids;
int64_t post_transform;
std::vector<int64_t> tree_roots;
private:
// `membership_values` are separated by NaN for different nodes.
// It is more convenient to preserve the values for each node in a vector;
// the vector is empty for nodes that are not `BRANCH_MEMBER`.
void getMembershipValuesById(std::vector<std::vector<ThresholdType>>& membership_values_by_id) const {
membership_values_by_id.clear();
membership_values_by_id.reserve(nodes_modes.size());
size_t curr_id = 0;
for (const auto node_mode : nodes_modes) {
membership_values_by_id.emplace_back();
if (node_mode != NODE_MODE_ONNX::BRANCH_MEMBER) {
continue;
}
while (curr_id < membership_values.size() && !_isnan_(membership_values[curr_id])) {
membership_values_by_id.back().push_back(membership_values[curr_id++]);
}
curr_id++;
}
}
std::string aggregateFunctionToString() const {
switch (aggregate_function) {
case static_cast<int64_t>(AGGREGATE_FUNCTION::AVERAGE):
return "AVERAGE";
case static_cast<int64_t>(AGGREGATE_FUNCTION::SUM):
return "SUM";
case static_cast<int64_t>(AGGREGATE_FUNCTION::MIN):
return "MIN";
case static_cast<int64_t>(AGGREGATE_FUNCTION::MAX):
return "MAX";
default:
ORT_THROW("Unknown value for aggregate_function.");
}
}
std::string postTransformToString() const {
switch (post_transform) {
case static_cast<int64_t>(POST_EVAL_TRANSFORM::NONE):
return "NONE";
case static_cast<int64_t>(POST_EVAL_TRANSFORM::SOFTMAX):
return "SOFTMAX";
case static_cast<int64_t>(POST_EVAL_TRANSFORM::LOGISTIC):
return "LOGISTIC";
case static_cast<int64_t>(POST_EVAL_TRANSFORM::SOFTMAX_ZERO):
return "SOFTMAX_ZERO";
case static_cast<int64_t>(POST_EVAL_TRANSFORM::PROBIT):
return "PROBIT";
default:
ORT_THROW("Unknown value for post_transform.");
}
}
int64_t transformInputOneTree(
const size_t curr_id, const int64_t curr_treeid, const int64_t curr_nodeid, const size_t curr_membership_value_id,
const bool is_leaf, std::vector<std::vector<ThresholdType>>& membership_values_by_id,
TreeEnsembleAttributesV3<ThresholdType>& output) const {
output.nodes_nodeids.push_back(curr_nodeid);
output.nodes_treeids.push_back(curr_treeid);
if (is_leaf) {
output.nodes_modes.push_back(NODE_MODE_ONNX::LEAF);
output.target_class_ids.push_back(leaf_targetids[curr_id]);
output.target_class_nodeids.push_back(curr_nodeid);
output.target_class_treeids.push_back(curr_treeid);
output.target_class_weights_as_tensor.push_back(leaf_weights[curr_id]);
// the fields below are irrelevant for a `LEAF`
output.nodes_featureids.push_back(0);
output.nodes_truenodeids.push_back(0);
output.nodes_falsenodeids.push_back(0);
output.nodes_values_as_tensor.push_back(0);
if (!nodes_hitrates.empty()) {
output.nodes_hitrates.push_back(0);
}
if (!nodes_missing_value_tracks_true.empty()) {
output.nodes_missing_value_tracks_true.push_back(0);
}
return curr_nodeid;
}
output.nodes_featureids.push_back(nodes_featureids[curr_id]);
if (!nodes_hitrates.empty()) {
output.nodes_hitrates_as_tensor.push_back(nodes_hitrates[curr_id]);
}
if (!nodes_missing_value_tracks_true.empty()) {
output.nodes_missing_value_tracks_true.push_back(nodes_missing_value_tracks_true[curr_id]);
}
// unroll `BRANCH_MEMBER` to a chain of `BRANCH_EQ`
if (nodes_modes[curr_id] == NODE_MODE_ONNX::BRANCH_MEMBER) {
output.nodes_modes.push_back(NODE_MODE_ONNX::BRANCH_EQ);
output.nodes_values_as_tensor.push_back(membership_values_by_id[curr_id][curr_membership_value_id]);
} else {
output.nodes_modes.push_back(nodes_modes[curr_id]);
output.nodes_values_as_tensor.push_back(nodes_splits[curr_id]);
}
size_t falsenodeid_id = output.nodes_falsenodeids.size();
output.nodes_falsenodeids.push_back(0); // change after pushing truenode subtree
int64_t true_nodeid = curr_nodeid + 1;
output.nodes_truenodeids.push_back(true_nodeid);
true_nodeid = transformInputOneTree(onnxruntime::narrow<size_t>(nodes_truenodeids[curr_id]),
curr_treeid, true_nodeid, 0U, nodes_trueleafs[curr_id] != 0,
membership_values_by_id, output);
int64_t false_nodeid = true_nodeid + 1;
output.nodes_falsenodeids[falsenodeid_id] = false_nodeid;
// If the node is `BRANCH_MEMBER`, we are unrolling its `membership_values`.
// Therefore, if the current value is not the last one, `falsenode_id` must point to the "same" node
// with a different membership value, so in that case we only advance the pointer into `membership_values`.
//
// Otherwise, `falsenode_id` points to the real false-node subtree.
if (nodes_modes[curr_id] == NODE_MODE_ONNX::BRANCH_MEMBER &&
curr_membership_value_id + 1 < membership_values_by_id[curr_id].size()) {
false_nodeid = transformInputOneTree(curr_id, curr_treeid, false_nodeid, curr_membership_value_id + 1, false,
membership_values_by_id, output);
} else {
false_nodeid = transformInputOneTree(onnxruntime::narrow<size_t>(nodes_falsenodeids[curr_id]),
curr_treeid, false_nodeid, 0U, nodes_falseleafs[curr_id] != 0,
membership_values_by_id, output);
}
return false_nodeid;
}
void transformInputAllTrees(TreeEnsembleAttributesV3<ThresholdType>& output,
std::vector<std::vector<ThresholdType>>& membership_values_by_id) const {
int64_t curr_treeid = 0;
for (const int64_t& tree_root : tree_roots) {
size_t tree_root_size_t = onnxruntime::narrow<size_t>(tree_root);
transformInputOneTree(tree_root_size_t, curr_treeid, 0, 0U,
nodes_falsenodeids[tree_root_size_t] == nodes_truenodeids[tree_root_size_t],
membership_values_by_id, output);
curr_treeid++;
}
}
};
} // namespace detail
} // namespace ml
} // namespace onnxruntime
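As described in the comments above `getMembershipValuesById`, `membership_values` is a single flat tensor in which per-node sets are separated by NaN, and only `BRANCH_MEMBER` nodes consume a segment. A simplified, standalone sketch of that decoding (outside the kernel):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

enum class Mode : uint8_t { BRANCH_LEQ = 0, BRANCH_MEMBER = 6, LEAF = 7 };

// Split the flat, NaN-separated membership_values tensor into one vector per node.
// Nodes that are not BRANCH_MEMBER get an empty vector.
std::vector<std::vector<float>> SplitMembership(const std::vector<Mode>& modes,
                                                const std::vector<float>& membership_values) {
  std::vector<std::vector<float>> per_node(modes.size());
  size_t cur = 0;
  for (size_t n = 0; n < modes.size(); ++n) {
    if (modes[n] != Mode::BRANCH_MEMBER) continue;
    while (cur < membership_values.size() && !std::isnan(membership_values[cur])) {
      per_node[n].push_back(membership_values[cur++]);
    }
    ++cur;  // skip the NaN separator
  }
  return per_node;
}

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  // Node 0 tests membership in {2, 5}, node 1 is a plain split, node 2 tests {7}.
  auto sets = SplitMembership({Mode::BRANCH_MEMBER, Mode::BRANCH_LEQ, Mode::BRANCH_MEMBER},
                              {2.f, 5.f, nan, 7.f, nan});
  assert(sets[0].size() == 2 && sets[1].empty() && sets[2].size() == 1);
  return 0;
}
```

During `convert_to_v3`, each decoded set is then unrolled into a chain of `BRANCH_EQ` nodes that all share the same true subtree.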

View file

@@ -3,15 +3,21 @@
#pragma once
#include "tree_ensemble_aggregator.h"
#include <mutex>
#include "core/platform/threadpool.h"
#include "tree_ensemble_helper.h"
#include "tree_ensemble_attribute.h"
#include "tree_ensemble_aggregator.h"
namespace onnxruntime {
namespace ml {
namespace detail {
/**
 * These attributes are the kernel attributes. They differ from the ONNX operator attributes
 * in order to improve computation efficiency. The initialization consists of moving the ONNX
 * attributes into the kernel attributes.
 */
class TreeEnsembleCommonAttributes {
public:
int64_t get_target_or_class_count() const { return this->n_targets_or_classes_; }
@@ -57,27 +63,7 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
Status Init(int parallel_tree,
int parallel_tree_N,
int parallel_N,
const std::string& aggregate_function,
const std::vector<float>& base_values,
const std::vector<ThresholdType>& base_values_as_tensor,
int64_t n_targets_or_classes,
const std::vector<int64_t>& nodes_falsenodeids,
const std::vector<int64_t>& nodes_featureids,
const std::vector<float>& nodes_hitrates,
const std::vector<ThresholdType>& nodes_hitrates_as_tensor,
const std::vector<int64_t>& nodes_missing_value_tracks_true,
const std::vector<std::string>& nodes_modes,
const std::vector<int64_t>& nodes_nodeids,
const std::vector<int64_t>& nodes_treeids,
const std::vector<int64_t>& nodes_truenodeids,
const std::vector<float>& nodes_values,
const std::vector<ThresholdType>& nodes_values_as_tensor,
const std::string& post_transform,
const std::vector<int64_t>& target_class_ids,
const std::vector<int64_t>& target_class_nodeids,
const std::vector<int64_t>& target_class_treeids,
const std::vector<float>& target_class_weights,
const std::vector<ThresholdType>& target_class_weights_as_tensor);
const TreeEnsembleAttributesV3<ThresholdType>& attributes);
protected:
TreeNodeElement<ThresholdType>* ProcessTreeNodeLeave(TreeNodeElement<ThresholdType>* root,
@@ -87,49 +73,52 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
void ComputeAgg(concurrency::ThreadPool* ttp, const Tensor* X, Tensor* Y, Tensor* label, const AGG& agg) const;
private:
size_t AddNodes(const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping,
int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids);
bool CheckIfSubtreesAreEqual(const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector<NODE_MODE_ONNX>& cmodes,
const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, gsl::span<const int64_t> nodes_featureids,
gsl::span<const ThresholdType> nodes_values_as_tensor, gsl::span<const float> node_values,
gsl::span<const float> target_class_weights, gsl::span<const ThresholdType> target_class_weights_as_tensor,
const InlinedVector<TreeNodeElementId>& node_tree_ids, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices);
size_t AddNodes(const size_t i, const InlinedVector<NODE_MODE_ONNX>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, gsl::span<const int64_t> nodes_featureids,
gsl::span<const ThresholdType> nodes_values_as_tensor, gsl::span<const float> node_values,
gsl::span<const int64_t> nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping,
int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids, gsl::span<const float> target_class_weights,
gsl::span<const ThresholdType> target_class_weights_as_tensor, InlinedVector<std::pair<TreeNodeElementId, uint32_t>>& indices);
};
// Below is a simple implementation of `bit_cast`, which is only available from C++20 while the
// currently supported standard is C++17. Remove it once that is no longer the case.
template <class To, class From>
std::enable_if_t<
sizeof(To) == sizeof(From) &&
std::is_trivially_copyable_v<From> &&
std::is_trivially_copyable_v<To>,
To>
// constexpr support needs compiler magic
static bit_cast(const From& src) noexcept {
static_assert(std::is_trivially_constructible_v<To>,
"This implementation additionally requires "
"destination type to be trivially constructible");
To dst;
std::memcpy(&dst, &src, sizeof(To));
return dst;
}
template <typename T>
std::conditional_t<sizeof(T) == sizeof(uint32_t), uint32_t, uint64_t> bit_cast_int(T val) {
if constexpr (sizeof(T) == sizeof(uint32_t)) {
return bit_cast<uint32_t>(val);
} else if constexpr (sizeof(T) == sizeof(uint64_t)) {
return bit_cast<uint64_t>(val);
}
static_assert(sizeof(T) == sizeof(uint32_t) || sizeof(T) == sizeof(uint64_t));
}
template <typename InputType, typename ThresholdType, typename OutputType>
Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(const OpKernelInfo& info) {
std::vector<ThresholdType> base_values_as_tensor, nodes_hitrates_as_tensor,
nodes_values_as_tensor, target_weights_as_tensor;
#if !defined(ORT_MINIMAL_BUILD)
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "target_weights_as_tensor", target_weights_as_tensor));
#endif
return Init(
80,
128,
50,
info.GetAttrOrDefault<std::string>("aggregate_function", "SUM"),
info.GetAttrsOrDefault<float>("base_values"),
base_values_as_tensor,
info.GetAttrOrDefault<int64_t>("n_targets", 0),
info.GetAttrsOrDefault<int64_t>("nodes_falsenodeids"),
info.GetAttrsOrDefault<int64_t>("nodes_featureids"),
info.GetAttrsOrDefault<float>("nodes_hitrates"),
nodes_hitrates_as_tensor,
info.GetAttrsOrDefault<int64_t>("nodes_missing_value_tracks_true"),
info.GetAttrsOrDefault<std::string>("nodes_modes"),
info.GetAttrsOrDefault<int64_t>("nodes_nodeids"),
info.GetAttrsOrDefault<int64_t>("nodes_treeids"),
info.GetAttrsOrDefault<int64_t>("nodes_truenodeids"),
info.GetAttrsOrDefault<float>("nodes_values"),
nodes_values_as_tensor,
info.GetAttrOrDefault<std::string>("post_transform", "NONE"),
info.GetAttrsOrDefault<int64_t>("target_ids"),
info.GetAttrsOrDefault<int64_t>("target_nodeids"),
info.GetAttrsOrDefault<int64_t>("target_treeids"),
info.GetAttrsOrDefault<float>("target_weights"),
target_weights_as_tensor);
TreeEnsembleAttributesV3<ThresholdType> attributes(info, false);
return Init(80, 128, 50, attributes);
}
template <typename InputType, typename ThresholdType, typename OutputType>
@@ -137,72 +126,35 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
int parallel_tree,
int parallel_tree_N,
int parallel_N,
const std::string& aggregate_function,
const std::vector<float>& base_values,
const std::vector<ThresholdType>& base_values_as_tensor,
int64_t n_targets_or_classes,
const std::vector<int64_t>& nodes_falsenodeids,
const std::vector<int64_t>& nodes_featureids,
const std::vector<float>& nodes_hitrates,
const std::vector<ThresholdType>& nodes_hitrates_as_tensor,
const std::vector<int64_t>& nodes_missing_value_tracks_true,
const std::vector<std::string>& nodes_modes,
const std::vector<int64_t>& nodes_nodeids,
const std::vector<int64_t>& nodes_treeids,
const std::vector<int64_t>& nodes_truenodeids,
const std::vector<float>& nodes_values,
const std::vector<ThresholdType>& nodes_values_as_tensor,
const std::string& post_transform,
const std::vector<int64_t>& target_class_ids,
const std::vector<int64_t>& target_class_nodeids,
const std::vector<int64_t>& target_class_treeids,
const std::vector<float>& target_class_weights,
const std::vector<ThresholdType>& target_class_weights_as_tensor) {
const TreeEnsembleAttributesV3<ThresholdType>& attributes) {
parallel_tree_ = parallel_tree;
parallel_tree_N_ = parallel_tree_N;
parallel_N_ = parallel_N;
ORT_ENFORCE(n_targets_or_classes > 0);
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_featureids.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_modes.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_nodeids.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_treeids.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_truenodeids.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_values.size() ||
nodes_falsenodeids.size() == nodes_values_as_tensor.size());
ORT_ENFORCE(target_class_ids.size() == target_class_nodeids.size());
ORT_ENFORCE(target_class_ids.size() == target_class_treeids.size());
ORT_ENFORCE(target_class_weights.empty() || target_class_ids.size() == target_class_weights.size());
ORT_ENFORCE(base_values.empty() || base_values_as_tensor.empty());
ORT_ENFORCE(nodes_hitrates.empty() || nodes_hitrates_as_tensor.empty());
ORT_ENFORCE(nodes_values.empty() || nodes_values_as_tensor.empty());
ORT_ENFORCE(target_class_weights.empty() || target_class_weights_as_tensor.empty());
aggregate_function_ = MakeAggregateFunction(aggregate_function);
post_transform_ = MakeTransform(post_transform);
if (!base_values_as_tensor.empty()) {
ORT_ENFORCE(base_values.empty());
base_values_ = base_values_as_tensor;
aggregate_function_ = MakeAggregateFunction(attributes.aggregate_function);
post_transform_ = MakeTransform(attributes.post_transform);
if (!attributes.base_values_as_tensor.empty()) {
ORT_ENFORCE(attributes.base_values.empty());
base_values_ = attributes.base_values_as_tensor;
} else {
base_values_.reserve(base_values.size());
for (size_t i = 0, limit = base_values.size(); i < limit; ++i) {
base_values_.push_back(static_cast<ThresholdType>(base_values[i]));
base_values_.reserve(attributes.base_values.size());
for (size_t i = 0, limit = attributes.base_values.size(); i < limit; ++i) {
base_values_.push_back(static_cast<ThresholdType>(attributes.base_values[i]));
}
}
n_targets_or_classes_ = n_targets_or_classes;
n_targets_or_classes_ = attributes.n_targets_or_classes;
max_tree_depth_ = 1000;
ORT_ENFORCE(nodes_modes.size() < std::numeric_limits<uint32_t>::max());
// Additional members
size_t limit;
uint32_t i;
InlinedVector<NODE_MODE> cmodes;
cmodes.reserve(nodes_modes.size());
InlinedVector<NODE_MODE_ONNX> cmodes;
cmodes.reserve(attributes.nodes_modes.size());
same_mode_ = true;
int fpos = -1;
for (i = 0, limit = nodes_modes.size(); i < limit; ++i) {
cmodes.push_back(MakeTreeNodeMode(nodes_modes[i]));
if (cmodes[i] == NODE_MODE::LEAF) continue;
for (i = 0, limit = attributes.nodes_modes.size(); i < limit; ++i) {
cmodes.push_back(attributes.nodes_modes[i]);
if (cmodes[i] == NODE_MODE_ONNX::LEAF) continue;
if (fpos == -1) {
fpos = static_cast<int>(i);
continue;
@@ -210,7 +162,7 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
if (cmodes[i] != cmodes[fpos]) same_mode_ = false;
}
n_nodes_ = nodes_treeids.size();
n_nodes_ = attributes.nodes_treeids.size();
limit = static_cast<size_t>(n_nodes_);
InlinedVector<TreeNodeElementId> node_tree_ids;
node_tree_ids.reserve(limit);
@@ -227,7 +179,7 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
// Build node_tree_ids and node_tree_ids_map and truenode_ids and falsenode_ids
for (i = 0; i < limit; ++i) {
TreeNodeElementId node_tree_id{static_cast<int>(nodes_treeids[i]), static_cast<int>(nodes_nodeids[i])};
TreeNodeElementId node_tree_id{static_cast<int>(attributes.nodes_treeids[i]), static_cast<int>(attributes.nodes_nodeids[i])};
auto p = node_tree_ids_map.insert(std::pair<TreeNodeElementId, size_t>(node_tree_id, i));
if (!p.second) {
ORT_THROW("Node ", node_tree_id.node_id, " in tree ", node_tree_id.tree_id, " is already there.");
@@ -237,13 +189,13 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
TreeNodeElementId coor;
for (i = 0; i < limit; ++i) {
if (cmodes[i] == NODE_MODE::LEAF) {
if (cmodes[i] == NODE_MODE_ONNX::LEAF) {
truenode_ids.push_back(0);
falsenode_ids.push_back(0);
} else {
TreeNodeElementId& node_tree_id = node_tree_ids[i];
coor.tree_id = node_tree_id.tree_id;
coor.node_id = static_cast<int>(nodes_truenodeids[i]);
coor.node_id = static_cast<int>(attributes.nodes_truenodeids[i]);
ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_));
auto found = node_tree_ids_map.find(coor);
@@ -255,7 +207,7 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
}
truenode_ids.emplace_back(found->second);
coor.node_id = static_cast<int>(nodes_falsenodeids[i]);
coor.node_id = static_cast<int>(attributes.nodes_falsenodeids[i]);
ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_));
found = node_tree_ids_map.find(coor);
if (found == node_tree_ids_map.end()) {
@@ -270,41 +222,38 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
}
}
// Sort targets
InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices;
indices.reserve(attributes.target_class_nodeids.size());
for (i = 0, limit = attributes.target_class_nodeids.size(); i < limit; i++) {
indices.emplace_back(
TreeNodeElementId{attributes.target_class_treeids[i], attributes.target_class_nodeids[i]}, i);
}
std::sort(indices.begin(), indices.end());
// Let's construct nodes_ such that the false branch is always the next element in nodes_.
// updated_mapping will translate the old position of each node to the new node position in nodes_.
std::vector<size_t> updated_mapping(nodes_treeids.size(), 0);
std::vector<size_t> updated_mapping(attributes.nodes_treeids.size(), 0);
int64_t previous_tree_id = -1;
for (i = 0; i < n_nodes_; ++i) {
if (previous_tree_id == -1 || (previous_tree_id != node_tree_ids[i].tree_id)) {
// New tree.
int64_t tree_id = node_tree_ids[i].tree_id;
size_t root_position =
AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values,
nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
AddNodes(i, cmodes, truenode_ids, falsenode_ids, attributes.nodes_featureids, attributes.nodes_values_as_tensor, attributes.nodes_values,
attributes.nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
attributes.target_class_weights, attributes.target_class_weights_as_tensor, indices);
roots_.push_back(&nodes_[root_position]);
previous_tree_id = tree_id;
}
}
n_trees_ = roots_.size();
if (((int64_t)nodes_.size()) != n_nodes_) {
ORT_THROW("Number of nodes in nodes_ (", nodes_.size(), ") is different from n_nodes (", n_nodes_, ").");
}
// Sort targets
InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices;
indices.reserve(target_class_nodeids.size());
for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) {
indices.emplace_back(
std::pair<TreeNodeElementId, uint32_t>(TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i));
}
std::sort(indices.begin(), indices.end());
TreeNodeElementId ind;
SparseValue<ThresholdType> w;
size_t indi;
for (indi = 0, limit = target_class_nodeids.size(); indi < limit; ++indi) {
for (indi = 0, limit = attributes.target_class_nodeids.size(); indi < limit; ++indi) {
ind = indices[indi].first;
i = indices[indi].second;
auto found = node_tree_ids_map.find(ind);
@@ -319,9 +268,10 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
// ORT_THROW("Node ", ind.tree_id, "-", ind.node_id, " is not a leaf.");
continue;
}
w.i = target_class_ids[i];
w.value = target_class_weights_as_tensor.empty() ? static_cast<ThresholdType>(target_class_weights[i])
: target_class_weights_as_tensor[i];
w.i = attributes.target_class_ids[i];
w.value = attributes.target_class_weights_as_tensor.empty()
? static_cast<ThresholdType>(attributes.target_class_weights[i])
: attributes.target_class_weights_as_tensor[i];
if (leaf.truenode_or_weight.weight_data.n_weights == 0) {
leaf.truenode_or_weight.weight_data.weight = static_cast<int32_t>(weights_.size());
leaf.value_or_unique_weight = w.value;
@@ -331,7 +281,7 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
}
has_missing_tracks_ = false;
for (auto itm = nodes_missing_value_tracks_true.begin(); itm != nodes_missing_value_tracks_true.end(); ++itm) {
for (auto itm = attributes.nodes_missing_value_tracks_true.begin(); itm != attributes.nodes_missing_value_tracks_true.end(); ++itm) {
if (*itm) {
has_missing_tracks_ = true;
break;
@@ -341,13 +291,58 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
return Status::OK();
}
template <typename InputType, typename ThresholdType, typename OutputType>
bool TreeEnsembleCommon<InputType, ThresholdType, OutputType>::CheckIfSubtreesAreEqual(
const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector<NODE_MODE_ONNX>& cmodes,
const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, gsl::span<const int64_t> nodes_featureids,
gsl::span<const ThresholdType> nodes_values_as_tensor, gsl::span<const float> node_values,
gsl::span<const float> target_class_weights, gsl::span<const ThresholdType> target_class_weights_as_tensor,
const InlinedVector<TreeNodeElementId>& node_tree_ids, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices) {
// Leaves have their values set to 0
if (cmodes[left_id] != cmodes[right_id] || nodes_featureids[left_id] != nodes_featureids[right_id] ||
(!nodes_values_as_tensor.empty() && nodes_values_as_tensor[left_id] != nodes_values_as_tensor[right_id]) ||
(nodes_values_as_tensor.empty() && node_values[left_id] != node_values[right_id])) {
return false;
}
if (cmodes[left_id] == NODE_MODE_ONNX::LEAF) {
const auto left_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[left_id], uint32_t(0)))->second;
const auto right_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[right_id], uint32_t(0)))->second;
if (target_class_weights_as_tensor.empty()) {
return target_class_weights[left_target_node] == target_class_weights[right_target_node];
} else {
return target_class_weights_as_tensor[left_target_node] == target_class_weights_as_tensor[right_target_node];
}
}
return CheckIfSubtreesAreEqual(falsenode_ids[left_id], falsenode_ids[right_id], tree_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids,
nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices) &&
CheckIfSubtreesAreEqual(truenode_ids[left_id], truenode_ids[right_id], tree_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids,
nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices);
}
inline void UpdateThreshold(double val, double& mask) {
uint64_t new_mask = bit_cast<uint64_t>(mask) | (1ll << (static_cast<uint32_t>(val) - 1));
mask = bit_cast<double>(new_mask);
}
inline void UpdateThreshold(float val, float& mask) {
uint32_t new_mask = bit_cast<uint32_t>(mask) | (1 << (static_cast<uint32_t>(val) - 1));
mask = bit_cast<float>(new_mask);
}
#define BITCOUNT(T) int64_t(sizeof(T) * 8)
#define CANMASK(v, T) (v >= 1 && v <= BITCOUNT(T)) && v == std::floor(v)
template <typename InputType, typename ThresholdType, typename OutputType>
size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping, int64_t tree_id,
const InlinedVector<TreeNodeElementId>& node_tree_ids) {
const size_t i, const InlinedVector<NODE_MODE_ONNX>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, gsl::span<const int64_t> nodes_featureids,
gsl::span<const ThresholdType> nodes_values_as_tensor, gsl::span<const float> node_values,
gsl::span<const int64_t> nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping, int64_t tree_id,
const InlinedVector<TreeNodeElementId>& node_tree_ids, gsl::span<const float> target_class_weights,
gsl::span<const ThresholdType> target_class_weights_as_tensor, InlinedVector<std::pair<TreeNodeElementId, uint32_t>>& indices) {
// Validate this index maps to the same tree_id as the one we should be building.
if (node_tree_ids[i].tree_id != tree_id) {
ORT_THROW("Tree id mismatch. Expected ", tree_id, " but got ", node_tree_ids[i].tree_id, " at position ", i);
@@ -364,28 +359,59 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
updated_mapping[i] = node_pos;
TreeNodeElement<ThresholdType> node;
node.flags = static_cast<uint8_t>(cmodes[i]);
node.flags = Convert_NODE_MODE_ONNX_to_ORT(cmodes[i]);
node.feature_id = static_cast<int>(nodes_featureids[i]);
if (node.feature_id > max_feature_id_) {
max_feature_id_ = node.feature_id;
}
node.value_or_unique_weight =
nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];
node.value_or_unique_weight = 0;
const ThresholdType node_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];
if (node.flags == NODE_MODE_ORT::BRANCH_EQ && CANMASK(node_threshold, ThresholdType)) {
UpdateThreshold(node_threshold, node.value_or_unique_weight);
node.flags = NODE_MODE_ORT::BRANCH_MEMBER;
} else {
node.value_or_unique_weight = node_threshold;
}
if (i < static_cast<size_t>(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) {
node.flags |= static_cast<uint8_t>(MissingTrack::kTrue);
node.flags = static_cast<NODE_MODE_ORT>(static_cast<uint8_t>(node.flags) | static_cast<uint8_t>(MissingTrack::kTrue));
}
nodes_.push_back(std::move(node));
if (nodes_[node_pos].is_not_leaf()) {
size_t falsenode_id = falsenode_ids[i];
// Categoricals are represented as a chain of `EQ` nodes whose true-child subtrees are identical for all nodes in the chain.
// Below we fold these nodes together into a single node of mode `BRANCH_MEMBER`.
// The threshold of this node should be interpreted as a bitmask recording which categorical values were found in the chain.
// Afterwards, when checking whether a feature value is included, we can AND the node's mask
// with the feature's mask (which has a single bit set at the position of its value).
// Beware that if a category value does not fit in the threshold type's bit width, the node stays `EQ` and no folding is done.
if (nodes_[node_pos].flags == NODE_MODE_ORT::BRANCH_MEMBER) {
ThresholdType falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];
while (cmodes[falsenode_id] == NODE_MODE_ONNX::BRANCH_EQ && nodes_[node_pos].feature_id == nodes_featureids[falsenode_id] &&
CANMASK(falsenode_threshold, ThresholdType) &&
CheckIfSubtreesAreEqual(truenode_ids[i], truenode_ids[falsenode_id], tree_id, cmodes, truenode_ids, falsenode_ids,
nodes_featureids, nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices)) {
UpdateThreshold(falsenode_threshold, nodes_[node_pos].value_or_unique_weight);
falsenode_id = falsenode_ids[falsenode_id];
falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];
}
}
size_t false_branch =
AddNodes(falsenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
AddNodes(falsenode_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
target_class_weights, target_class_weights_as_tensor, indices);
if (false_branch != node_pos + 1) {
ORT_THROW("False node must always be the next node, but it isn't at index ", node_pos, " with flags ",
static_cast<int>(nodes_[node_pos].flags));
}
size_t true_branch =
AddNodes(truenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
target_class_weights, target_class_weights_as_tensor, indices);
// We don't need to store the false branch pointer since we know it is always in the immediate next entry in nodes_.
// nodes_[node_pos].falsenode_inc_or_n_weights.ptr = &nodes_[false_branch];
nodes_[node_pos].truenode_or_weight.ptr = &nodes_[true_branch];
@@ -684,10 +710,12 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
} \
}
inline bool _isnan_(float x) { return std::isnan(x); }
inline bool _isnan_(double x) { return std::isnan(x); }
inline bool _isnan_(int64_t) { return false; }
inline bool _isnan_(int32_t) { return false; }
// Check whether the bit corresponding to the feature value is set in the mask
template <typename T1, typename T2>
inline bool SetMembershipCheck(T1 val, T2 mask) {
const int64_t val_as_int = static_cast<int64_t>(val);
return CANMASK(val, T2) && (((1ll << (val_as_int - 1)) & bit_cast_int(mask)) != 0);
}
template <typename InputType, typename ThresholdType, typename OutputType>
TreeNodeElement<ThresholdType>*
@@ -696,7 +724,7 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
InputType val;
if (same_mode_) {
switch (root->mode()) {
case NODE_MODE::BRANCH_LEQ:
case NODE_MODE_ORT::BRANCH_LEQ:
if (has_missing_tracks_) {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
@@ -711,22 +739,36 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
}
}
break;
case NODE_MODE::BRANCH_LT:
case NODE_MODE_ORT::BRANCH_LT:
TREE_FIND_VALUE(<)
break;
case NODE_MODE::BRANCH_GTE:
case NODE_MODE_ORT::BRANCH_GTE:
TREE_FIND_VALUE(>=)
break;
case NODE_MODE::BRANCH_GT:
case NODE_MODE_ORT::BRANCH_GT:
TREE_FIND_VALUE(>)
break;
case NODE_MODE::BRANCH_EQ:
case NODE_MODE_ORT::BRANCH_EQ:
TREE_FIND_VALUE(==)
break;
case NODE_MODE::BRANCH_NEQ:
case NODE_MODE_ORT::BRANCH_NEQ:
TREE_FIND_VALUE(!=)
break;
case NODE_MODE::LEAF:
case NODE_MODE_ORT::BRANCH_MEMBER:
if (has_missing_tracks_) {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val)))
? root->truenode_or_weight.ptr
: root + 1;
}
} else {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
root = SetMembershipCheck(val, root->value_or_unique_weight) ? root->truenode_or_weight.ptr : root + 1;
}
}
break;
case NODE_MODE_ORT::LEAF:
break;
}
} else { // Different rules to compare to node thresholds.
@@ -735,31 +777,36 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
val = x_data[root->feature_id];
threshold = root->value_or_unique_weight;
switch (root->mode()) {
case NODE_MODE::BRANCH_LEQ:
case NODE_MODE_ORT::BRANCH_LEQ:
root = val <= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_LT:
case NODE_MODE_ORT::BRANCH_LT:
root = val < threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_GTE:
case NODE_MODE_ORT::BRANCH_GTE:
root = val >= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_GT:
case NODE_MODE_ORT::BRANCH_GT:
root = val > threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_EQ:
case NODE_MODE_ORT::BRANCH_EQ:
root = val == threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_NEQ:
case NODE_MODE_ORT::BRANCH_NEQ:
root = val != threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::LEAF:
case NODE_MODE_ORT::BRANCH_MEMBER:
root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val)))
? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE_ORT::LEAF:
return root;
}
}
@@ -786,67 +833,13 @@ class TreeEnsembleCommonClassifier : public TreeEnsembleCommon<InputType, Thresh
Status Init(int parallel_tree,
int parallel_tree_N,
int parallel_N,
const std::string& aggregate_function,
const std::vector<float>& base_values,
const std::vector<ThresholdType>& base_values_as_tensor,
const std::vector<int64_t>& nodes_falsenodeids,
const std::vector<int64_t>& nodes_featureids,
const std::vector<float>& nodes_hitrates,
const std::vector<ThresholdType>& nodes_hitrates_as_tensor,
const std::vector<int64_t>& nodes_missing_value_tracks_true,
const std::vector<std::string>& nodes_modes,
const std::vector<int64_t>& nodes_nodeids,
const std::vector<int64_t>& nodes_treeids,
const std::vector<int64_t>& nodes_truenodeids,
const std::vector<float>& nodes_values,
const std::vector<ThresholdType>& nodes_values_as_tensor,
const std::string& post_transform,
const std::vector<int64_t>& class_ids,
const std::vector<int64_t>& class_nodeids,
const std::vector<int64_t>& class_treeids,
const std::vector<float>& class_weights,
const std::vector<ThresholdType>& class_weights_as_tensor,
const std::vector<std::string>& classlabels_strings,
const std::vector<int64_t>& classlabels_int64s);
const TreeEnsembleAttributesV3<ThresholdType>& attributes);
};
template <typename InputType, typename ThresholdType, typename OutputType>
Status TreeEnsembleCommonClassifier<InputType, ThresholdType, OutputType>::Init(const OpKernelInfo& info) {
std::vector<ThresholdType> base_values_as_tensor, nodes_hitrates_as_tensor,
nodes_values_as_tensor, class_weights_as_tensor;
#if !defined(ORT_MINIMAL_BUILD)
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "class_weights_as_tensor", class_weights_as_tensor));
#endif
return Init(
80,
128,
50,
info.GetAttrOrDefault<std::string>("aggregate_function", "SUM"),
info.GetAttrsOrDefault<float>("base_values"),
base_values_as_tensor,
info.GetAttrsOrDefault<int64_t>("nodes_falsenodeids"),
info.GetAttrsOrDefault<int64_t>("nodes_featureids"),
info.GetAttrsOrDefault<float>("nodes_hitrates"),
nodes_hitrates_as_tensor,
info.GetAttrsOrDefault<int64_t>("nodes_missing_value_tracks_true"),
info.GetAttrsOrDefault<std::string>("nodes_modes"),
info.GetAttrsOrDefault<int64_t>("nodes_nodeids"),
info.GetAttrsOrDefault<int64_t>("nodes_treeids"),
info.GetAttrsOrDefault<int64_t>("nodes_truenodeids"),
info.GetAttrsOrDefault<float>("nodes_values"),
nodes_values_as_tensor,
info.GetAttrOrDefault<std::string>("post_transform", "NONE"),
info.GetAttrsOrDefault<int64_t>("class_ids"),
info.GetAttrsOrDefault<int64_t>("class_nodeids"),
info.GetAttrsOrDefault<int64_t>("class_treeids"),
info.GetAttrsOrDefault<float>("class_weights"),
class_weights_as_tensor,
info.GetAttrsOrDefault<std::string>("classlabels_strings"),
info.GetAttrsOrDefault<int64_t>("classlabels_int64s"));
TreeEnsembleAttributesV3<ThresholdType> attributes(info, true);
return Init(80, 128, 50, attributes);
}
template <typename InputType, typename ThresholdType, typename OutputType>
@@ -854,65 +847,20 @@ Status TreeEnsembleCommonClassifier<InputType, ThresholdType, OutputType>::Init(
int parallel_tree,
int parallel_tree_N,
int parallel_N,
const std::string& aggregate_function,
const std::vector<float>& base_values,
const std::vector<ThresholdType>& base_values_as_tensor,
const std::vector<int64_t>& nodes_falsenodeids,
const std::vector<int64_t>& nodes_featureids,
const std::vector<float>& nodes_hitrates,
const std::vector<ThresholdType>& nodes_hitrates_as_tensor,
const std::vector<int64_t>& nodes_missing_value_tracks_true,
const std::vector<std::string>& nodes_modes,
const std::vector<int64_t>& nodes_nodeids,
const std::vector<int64_t>& nodes_treeids,
const std::vector<int64_t>& nodes_truenodeids,
const std::vector<float>& nodes_values,
const std::vector<ThresholdType>& nodes_values_as_tensor,
const std::string& post_transform,
const std::vector<int64_t>& class_ids,
const std::vector<int64_t>& class_nodeids,
const std::vector<int64_t>& class_treeids,
const std::vector<float>& class_weights,
const std::vector<ThresholdType>& class_weights_as_tensor,
const std::vector<std::string>& classlabels_strings,
const std::vector<int64_t>& classlabels_int64s) {
auto status = TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
parallel_tree,
parallel_tree_N,
parallel_N,
aggregate_function,
base_values,
base_values_as_tensor,
classlabels_strings.empty() ? classlabels_int64s.size()
: classlabels_strings.size(),
nodes_falsenodeids,
nodes_featureids,
nodes_hitrates,
nodes_hitrates_as_tensor,
nodes_missing_value_tracks_true,
nodes_modes,
nodes_nodeids,
nodes_treeids,
nodes_truenodeids,
nodes_values,
nodes_values_as_tensor,
post_transform,
class_ids,
class_nodeids,
class_treeids,
class_weights,
class_weights_as_tensor);
const TreeEnsembleAttributesV3<ThresholdType>& attributes) {
auto status = TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(parallel_tree, parallel_tree_N, parallel_N, attributes);
ORT_RETURN_IF_ERROR(status);
classlabels_strings_ = classlabels_strings;
classlabels_int64s_ = classlabels_int64s;
classlabels_strings_ = attributes.classlabels_strings;
classlabels_int64s_ = attributes.classlabels_int64s;
InlinedHashSet<int64_t> weights_classes;
weights_classes.reserve(class_ids.size());
weights_classes.reserve(attributes.target_class_ids.size());
weights_are_all_positive_ = true;
for (size_t i = 0, end = class_ids.size(); i < end; ++i) {
weights_classes.insert(class_ids[i]);
if (weights_are_all_positive_ && (!class_weights.empty() ? class_weights[i] : class_weights_as_tensor[i]) < 0)
for (size_t i = 0, end = attributes.target_class_ids.size(); i < end; ++i) {
weights_classes.insert(attributes.target_class_ids[i]);
if (weights_are_all_positive_ && (!attributes.target_class_weights.empty() ? attributes.target_class_weights[i]
: attributes.target_class_weights_as_tensor[i]) < 0)
weights_are_all_positive_ = false;
}
binary_case_ = this->n_targets_or_classes_ == 2 && weights_classes.size() == 1;
@@ -957,6 +905,43 @@ Status TreeEnsembleCommonClassifier<InputType, ThresholdType, OutputType>::compu
return Status::OK();
}
template <typename IOType, typename ThresholdType>
class TreeEnsembleCommonV5 : public TreeEnsembleCommon<IOType, ThresholdType, IOType> {
public:
virtual Status Init(const OpKernelInfo& info);
Status Init(int parallel_tree,
int parallel_tree_N,
int parallel_N,
const TreeEnsembleAttributesV5<ThresholdType>& attributes);
};
template <typename IOType, typename ThresholdType>
Status TreeEnsembleCommonV5<IOType, ThresholdType>::Init(const OpKernelInfo& info) {
TreeEnsembleAttributesV5<ThresholdType> attributes(info);
return Init(80, 128, 50, attributes);
}
template <typename IOType, typename ThresholdType>
Status TreeEnsembleCommonV5<IOType, ThresholdType>::Init(
int parallel_tree,
int parallel_tree_N,
int parallel_N,
const TreeEnsembleAttributesV5<ThresholdType>& attributes) {
TreeEnsembleAttributesV3<ThresholdType> attributes_v3;
attributes.convert_to_v3(attributes_v3);
attributes_v3.base_values.clear();
attributes_v3.base_values_as_tensor.clear();
attributes_v3.nodes_hitrates.clear();
attributes_v3.nodes_values.clear();
attributes_v3.target_class_weights.clear();
auto status = TreeEnsembleCommon<IOType, ThresholdType, IOType>::Init(parallel_tree, parallel_tree_N, parallel_N, attributes_v3);
ORT_RETURN_IF_ERROR(status);
return Status::OK();
}
} // namespace detail
} // namespace ml
} // namespace onnxruntime
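The V5 kernel above reuses the existing V3 implementation by converting its attributes first. As a rough illustration of what that conversion has to do (this is not the PR's actual convert_to_v3, which also handles per-tree node numbering, set-membership nodes and tensor-typed attributes), the standalone snippet below flattens the opset-5 child encoding, where nodes_trueleafs/nodes_falseleafs say whether a child id points into the node arrays or into the leaf_* arrays, into a single opset-3-style node list in which each leaf becomes its own LEAF node. The arrays are the toy tree used by the new unit tests further down.

// Standalone illustration only: a naive flattening of the opset-5 leaf indirection
// into an opset-3 style node numbering. Not the code this PR adds.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Opset-5 style toy tree (same shape as the one built in the new unit tests).
  std::vector<int64_t> nodes_truenodeids = {1, 0, 1};
  std::vector<int64_t> nodes_trueleafs = {0, 1, 1};    // 1 => the id above indexes the leaf_* arrays
  std::vector<int64_t> nodes_falsenodeids = {2, 2, 3};
  std::vector<int64_t> nodes_falseleafs = {0, 1, 1};
  std::vector<int64_t> leaf_targetids = {0, 1, 0, 1};
  std::vector<double> leaf_weights = {5.23, 12.12, -12.23, 7.21};

  // Opset-3 style output: leaves get node ids appended after the internal nodes.
  const int64_t n_internal = static_cast<int64_t>(nodes_truenodeids.size());
  auto child_to_v3 = [&](int64_t id, int64_t is_leaf) {
    return is_leaf ? n_internal + id : id;  // leaf k becomes node (n_internal + k)
  };

  std::printf("v3 node  true_child  false_child\n");
  for (int64_t i = 0; i < n_internal; ++i) {
    std::printf("%7lld  %10lld  %11lld\n",
                (long long)i,
                (long long)child_to_v3(nodes_truenodeids[i], nodes_trueleafs[i]),
                (long long)child_to_v3(nodes_falsenodeids[i], nodes_falseleafs[i]));
  }
  // Each leaf k then becomes a LEAF node (n_internal + k) whose target id and weight
  // come straight from leaf_targetids[k] / leaf_weights[k].
  for (int64_t k = 0; k < (int64_t)leaf_weights.size(); ++k) {
    std::printf("leaf %lld -> LEAF node %lld (target %lld, weight %g)\n",
                (long long)k, (long long)(n_internal + k),
                (long long)leaf_targetids[k], leaf_weights[k]);
  }
  return 0;
}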


@@ -5,63 +5,53 @@
#include "core/providers/cpu/ml/tree_ensemble_helper.h"
#include "core/common/common.h"
#include "core/common/safeint.h"
#include "onnx/defs/tensor_proto_util.h"
#include "core/framework/tensorprotoutils.h"
using namespace ::onnxruntime::common;
using namespace std;
namespace onnxruntime {
namespace ml {
Status GetNumberOfElementsAttrsOrDefault(const OpKernelInfo& info, const std::string& name,
ONNX_NAMESPACE::TensorProto_DataType proto_type,
size_t& n_elements, ONNX_NAMESPACE::TensorProto& proto) {
auto status = info.GetAttr(name, &proto);
if (!status.IsOK()) {
// Attribute is missing, n_elements is set to 0.
n_elements = 0;
return Status::OK();
}
auto n_dims = proto.dims_size();
if (n_dims == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attribute:'", name, "' is specified but is empty.");
}
ORT_ENFORCE(n_dims == 1, "Attribute '", name, "' must be a vector.");
ORT_ENFORCE(proto.data_type() == proto_type,
"Unexpected type (", proto.data_type(), "(for attribute '", name, "'.");
n_elements = onnxruntime::narrow<size_t>(proto.dims()[0]);
ORT_ENFORCE(n_elements > 0, "Attribute '", name, "' has one dimension but is empty.");
return Status::OK();
}
template <typename TH>
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name,
ONNX_NAMESPACE::TensorProto_DataType proto_type, std::vector<TH>& data) {
if (proto_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE) {
ORT_ENFORCE((std::is_same<double, TH>::value));
} else if (proto_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) {
ORT_ENFORCE((std::is_same<float, TH>::value));
} else {
ORT_NOT_IMPLEMENTED("GetVectorAttrsOrDefault not implemented for type ", proto_type);
}
template <typename T>
Status GetAnyVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<T>& data) {
ONNX_NAMESPACE::TensorProto proto;
size_t n_elements;
data.clear();
ORT_THROW_IF_ERROR(GetNumberOfElementsAttrsOrDefault(info, name, proto_type, n_elements, proto));
if (n_elements == 0) {
auto result = info.GetAttr(name, &proto);
SafeInt<int64_t> n_elements(1);
for (auto dim : proto.dims()) {
n_elements *= dim;
}
if (proto.dims().empty()) {
return Status::OK();
}
data = ONNX_NAMESPACE::ParseData<TH>(&proto);
const SafeInt<size_t> tensor_size(n_elements);
data.clear();
data.resize(tensor_size);
result = utils::UnpackTensor<T>(proto, std::filesystem::path(), data.data(), tensor_size);
ORT_ENFORCE(result.IsOK(), "TreeEnsemble could not unpack tensor attribute ", name);
return Status::OK();
}
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<double>& data) {
return GetVectorAttrsOrDefault(info, name, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE, data);
return GetAnyVectorAttrsOrDefault(info, name, data);
}
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<float>& data) {
return GetVectorAttrsOrDefault(info, name, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, data);
return GetAnyVectorAttrsOrDefault(info, name, data);
}
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<MLFloat16>& data) {
return GetAnyVectorAttrsOrDefault(info, name, data);
}
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<uint8_t>& data) {
return GetAnyVectorAttrsOrDefault(info, name, data);
}
} // namespace ml


@@ -13,6 +13,8 @@ namespace ml {
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<double>& data);
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<float>& data);
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<MLFloat16>& data);
Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector<uint8_t>& data);
} // namespace ml
} // namespace onnxruntime


@@ -132,6 +132,7 @@ DEFAULT_OP_BLOCK_LIST = [
"Scaler",
"TreeEnsembleClassifier",
"TreeEnsembleRegressor",
"TreeEnsemble",
"ZipMap",
"NonMaxSuppression",
"TopK",


@@ -0,0 +1,294 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
static ONNX_NAMESPACE::TensorProto make_tensor(std::vector<double> array, std::string name) {
ONNX_NAMESPACE::TensorProto array_as_tensor;
array_as_tensor.set_name(name);
array_as_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE);
array_as_tensor.add_dims(array.size());
for (auto v : array) {
array_as_tensor.add_double_data(v);
}
return array_as_tensor;
}
static ONNX_NAMESPACE::TensorProto make_tensor(std::vector<float> array, std::string name) {
ONNX_NAMESPACE::TensorProto array_as_tensor;
array_as_tensor.set_name(name);
array_as_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
array_as_tensor.add_dims(array.size());
for (auto v : array) {
array_as_tensor.add_float_data(v);
}
return array_as_tensor;
}
static ONNX_NAMESPACE::TensorProto make_tensor(std::vector<uint8_t> array, std::string name) {
ONNX_NAMESPACE::TensorProto array_as_tensor;
array_as_tensor.set_name(name);
array_as_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
array_as_tensor.add_dims(array.size());
for (const auto v : array) {
array_as_tensor.add_int32_data(v);
}
return array_as_tensor;
}
template <typename T>
void _multiply_update_array(std::vector<T>& data, int n, T inc = 0) {
std::vector<T> copy = data;
data.resize(copy.size() * n);
T cst = 0;
for (int i = 0; i < n; ++i) {
for (size_t j = 0; j < copy.size(); ++j) {
data[j + i * copy.size()] = copy[j] + cst;
}
cst += inc;
}
}
template <typename T>
void _multiply_update_childnode(std::vector<T>& childnodes, std::vector<T>& childleafs, std::vector<T>& otherchildleafs, int n) {
int64_t leafs_cnt = 0;
int64_t nodes_cnt = childnodes.size();
for (auto& childleaf : childleafs) {
if (childleaf) {
leafs_cnt++;
}
}
for (auto& childleaf : otherchildleafs) {
if (childleaf) {
leafs_cnt++;
}
}
std::vector<T> copy = childnodes;
childnodes.resize(copy.size() * n);
T leafs_cst = 0;
T nodes_cst = 0;
for (int i = 0; i < n; ++i) {
for (size_t j = 0; j < copy.size(); ++j) {
T curr_inc = childleafs[j] ? leafs_cst : nodes_cst;
childnodes[j + i * copy.size()] = copy[j] + curr_inc;
}
leafs_cst += leafs_cnt;
nodes_cst += nodes_cnt;
}
}
template <typename T>
void _multiply_arrays_values(std::vector<T>& data, int64_t val) {
for (auto& curr : data) {
curr *= val;
}
}
template <typename T>
void GenTreeAndRunTest(const std::vector<T>& X, const std::vector<T>& Y, const int64_t& aggregate_function, int n_trees = 1) {
OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain);
int64_t n_targets = 2;
int64_t post_transform = 0;
std::vector<int64_t> tree_roots = {0};
std::vector<int64_t> nodes_featureids = {0, 0, 0};
std::vector<uint8_t> nodes_modes = {0, 0, 0};
std::vector<T> nodes_splits = {3.14f, 1.2f, 4.2f};
std::vector<int64_t> nodes_truenodeids = {1, 0, 1};
std::vector<int64_t> nodes_trueleafs = {0, 1, 1};
std::vector<int64_t> nodes_falsenodeids = {2, 2, 3};
std::vector<int64_t> nodes_falseleafs = {0, 1, 1};
std::vector<int64_t> leaf_targetids = {0, 1, 0, 1};
std::vector<T> leaf_weights = {5.23f, 12.12f, -12.23f, 7.21f};
if (n_trees > 1) {
// Replicate the tree n_trees times to exercise parallelization across trees.
_multiply_update_array(tree_roots, n_trees, (int64_t)nodes_truenodeids.size());
_multiply_update_array(nodes_featureids, n_trees);
_multiply_update_childnode(nodes_truenodeids, nodes_trueleafs, nodes_falseleafs, n_trees);
_multiply_update_childnode(nodes_falsenodeids, nodes_falseleafs, nodes_trueleafs, n_trees);
_multiply_update_array(nodes_trueleafs, n_trees);
_multiply_update_array(nodes_falseleafs, n_trees);
_multiply_update_array(leaf_targetids, n_trees);
_multiply_update_array(nodes_modes, n_trees);
_multiply_update_array(nodes_splits, n_trees);
_multiply_update_array(leaf_weights, n_trees);
}
auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes");
auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits");
auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight");
// add attributes
test.AddAttribute("n_targets", n_targets);
test.AddAttribute("aggregate_function", aggregate_function);
test.AddAttribute("post_transform", post_transform);
test.AddAttribute("tree_roots", tree_roots);
test.AddAttribute("nodes_modes", nodes_modes_as_tensor);
test.AddAttribute("nodes_featureids", nodes_featureids);
test.AddAttribute("nodes_splits", nodes_splits_as_tensor);
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
test.AddAttribute("nodes_trueleafs", nodes_trueleafs);
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
test.AddAttribute("nodes_falseleafs", nodes_falseleafs);
test.AddAttribute("leaf_targetids", leaf_targetids);
test.AddAttribute("leaf_weights", leaf_weights_as_tensor);
// fill input data
test.AddInput<T>("X", {3, 2}, X);
test.AddOutput<T>("Y", {3, 2}, Y);
test.Run();
}
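To make the attribute layout above easier to follow, here is a standalone hand-evaluation of the same toy tree on the three input rows the test feeds it. It assumes that mode 0 encodes a less-than-or-equal split and that aggregate_function 1 means plain summation, which is what the expected outputs imply; it is an illustration, not the kernel's traversal code.

// Standalone sketch (not part of the PR): hand-evaluation of the toy tree built by
// GenTreeAndRunTest, showing how the opset-5 arrays are walked.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t n_targets = 2;
  std::vector<int64_t> nodes_featureids = {0, 0, 0};
  std::vector<double> nodes_splits = {3.14, 1.2, 4.2};  // every node compares with <=
  std::vector<int64_t> nodes_truenodeids = {1, 0, 1};
  std::vector<int64_t> nodes_trueleafs = {0, 1, 1};
  std::vector<int64_t> nodes_falsenodeids = {2, 2, 3};
  std::vector<int64_t> nodes_falseleafs = {0, 1, 1};
  std::vector<int64_t> leaf_targetids = {0, 1, 0, 1};
  std::vector<double> leaf_weights = {5.23, 12.12, -12.23, 7.21};

  std::vector<std::vector<double>> X = {{1.2, 3.4}, {-0.12, 1.66}, {4.14, 1.77}};
  for (const auto& row : X) {
    std::vector<double> y(n_targets, 0.0);
    int64_t node = 0;  // tree_roots = {0}
    while (true) {
      const bool go_true = row[nodes_featureids[node]] <= nodes_splits[node];
      const int64_t child = go_true ? nodes_truenodeids[node] : nodes_falsenodeids[node];
      const bool child_is_leaf = (go_true ? nodes_trueleafs[node] : nodes_falseleafs[node]) != 0;
      if (child_is_leaf) {
        y[leaf_targetids[child]] += leaf_weights[child];  // sum aggregation
        break;
      }
      node = child;
    }
    std::printf("[%g, %g]\n", y[0], y[1]);  // prints 5.23/0, 5.23/0, 0/12.12
  }
  return 0;
}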
template <typename T>
void GenTreeAndRunTestWithSetMembership(const std::vector<T>& X, const std::vector<T>& Y, const int64_t& aggregate_function, int n_trees = 1) {
OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain);
int64_t n_targets = 4;
int64_t post_transform = 0;
std::vector<int64_t> tree_roots = {0};
std::vector<int64_t> nodes_featureids = {0, 0, 0};
std::vector<int64_t> nodes_truenodeids = {1, 0, 1};
std::vector<int64_t> nodes_trueleafs = {0, 1, 1};
std::vector<int64_t> nodes_falsenodeids = {2, 2, 3};
std::vector<int64_t> nodes_falseleafs = {1, 0, 1};
std::vector<int64_t> leaf_targetids = {0, 1, 2, 3};
std::vector<uint8_t> nodes_modes = {0, 6, 6};
std::vector<T> nodes_splits = {11.f, 232344.f, NAN};
std::vector<T> membership_values = {1.2f, 3.7f, 8.f, 9.f, NAN, 12.f, 7.f, NAN};
std::vector<T> leaf_weights = {1.f, 10.f, 1000.f, 100.f};
if (n_trees > 1) {
// Replicate the tree n_trees times to exercise parallelization across trees.
_multiply_update_array(tree_roots, n_trees, (int64_t)nodes_truenodeids.size());
_multiply_update_array(nodes_featureids, n_trees);
_multiply_update_childnode(nodes_truenodeids, nodes_trueleafs, nodes_falseleafs, n_trees);
_multiply_update_childnode(nodes_falsenodeids, nodes_falseleafs, nodes_trueleafs, n_trees);
_multiply_update_array(nodes_trueleafs, n_trees);
_multiply_update_array(nodes_falseleafs, n_trees);
_multiply_update_array(leaf_targetids, n_trees);
_multiply_update_array(nodes_modes, n_trees);
_multiply_update_array(nodes_splits, n_trees);
_multiply_update_array(membership_values, n_trees);
_multiply_update_array(leaf_weights, n_trees);
}
auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes");
auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits");
auto membership_values_as_tensor = make_tensor(membership_values, "membership_values");
auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight");
// add attributes
test.AddAttribute("n_targets", n_targets);
test.AddAttribute("aggregate_function", aggregate_function);
test.AddAttribute("post_transform", post_transform);
test.AddAttribute("tree_roots", tree_roots);
test.AddAttribute("nodes_modes", nodes_modes_as_tensor);
test.AddAttribute("nodes_featureids", nodes_featureids);
test.AddAttribute("nodes_splits", nodes_splits_as_tensor);
test.AddAttribute("membership_values", membership_values_as_tensor);
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
test.AddAttribute("nodes_trueleafs", nodes_trueleafs);
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
test.AddAttribute("nodes_falseleafs", nodes_falseleafs);
test.AddAttribute("leaf_targetids", leaf_targetids);
test.AddAttribute("leaf_weights", leaf_weights_as_tensor);
// fill input data
test.AddInput<T>("X", {6, 1}, X);
test.AddOutput<T>("Y", {6, 4}, Y);
test.Run();
}
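The membership_values attribute used above is a single flat tensor; as far as the opset-5 definition goes, it appears to hold the member sets of every BRANCH_MEMBER node concatenated in the order those nodes occur in nodes_modes, with each set terminated by a NaN. The standalone sketch below splits the test's tensor that way; it deliberately ignores the corner case of NaN itself being a set member, which this naive split cannot represent.

// Standalone sketch (not part of the PR): splitting the flat membership_values
// attribute into one value set per BRANCH_MEMBER node, with NaN closing each set.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Same data as GenTreeAndRunTestWithSetMembership: nodes 1 and 2 are BRANCH_MEMBER.
  std::vector<double> membership_values = {1.2, 3.7, 8.0, 9.0, NAN, 12.0, 7.0, NAN};

  std::vector<std::vector<double>> sets(1);
  for (double v : membership_values) {
    if (std::isnan(v)) {
      sets.emplace_back();  // NaN closes the current set
    } else {
      sets.back().push_back(v);
    }
  }
  if (sets.back().empty()) sets.pop_back();  // drop the empty set after the trailing NaN

  for (size_t i = 0; i < sets.size(); ++i) {
    std::printf("BRANCH_MEMBER set #%zu: {", i);
    for (double v : sets[i]) std::printf(" %g", v);
    std::printf(" }\n");
  }
  // Prints { 1.2 3.7 8 9 } for node 1 and { 12 7 } for node 2 of the test tree.
  return 0;
}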
TEST(MLOpTest, TreeEnsembleFloat) {
std::vector<float> X = {1.2f, 3.4f, -0.12f, 1.66f, 4.14f, 1.77f};
std::vector<float> Y = {5.23f, 0.f, 5.23f, 0.f, 0.f, 12.12f};
GenTreeAndRunTest<float>(X, Y, 1, 1);
Y = {15.69f, 0.f, 15.69f, 0.f, 0.f, 36.36f};
GenTreeAndRunTest<float>(X, Y, 1, 3);
}
TEST(MLOpTest, TreeEnsembleDouble) {
std::vector<double> X = {1.2f, 3.4f, -0.12f, 1.66f, 4.14f, 1.77f};
std::vector<double> Y = {5.23f, 0.f, 5.23f, 0.f, 0.f, 12.12f};
GenTreeAndRunTest<double>(X, Y, 1, 1);
_multiply_arrays_values(Y, 3);
GenTreeAndRunTest<double>(X, Y, 1, 3);
}
TEST(MLOpTest, TreeEnsembleSetMembership) {
std::vector<double> X = {1.2f, 3.4f, -0.12f, NAN, 12.0f, 7.0f};
std::vector<double> Y = {
1.f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 100.f,
0.f, 0.f, 0.f, 100.f,
0.f, 0.f, 1000.f, 0.f,
0.f, 0.f, 1000.f, 0.f,
0.f, 10.f, 0.f, 0.f};
GenTreeAndRunTestWithSetMembership<double>(X, Y, 1, 1);
_multiply_arrays_values(Y, 5);
GenTreeAndRunTestWithSetMembership<double>(X, Y, 1, 5);
}
TEST(MLOpTest, TreeEnsembleLeafOnly) {
OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain);
int64_t n_targets = 1;
int64_t aggregate_function = 1;
int64_t post_transform = 0;
std::vector<int64_t> tree_roots = {0};
std::vector<uint8_t> nodes_modes = {0};
std::vector<int64_t> nodes_featureids = {0};
std::vector<double> nodes_splits = {0.f};
std::vector<int64_t> nodes_truenodeids = {0};
std::vector<int64_t> nodes_trueleafs = {1};
std::vector<int64_t> nodes_falsenodeids = {0};
std::vector<int64_t> nodes_falseleafs = {1};
std::vector<int64_t> leaf_targetids = {0};
std::vector<double> leaf_weights = {6.23f};
auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes");
auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits");
auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight");
// add attributes
test.AddAttribute("n_targets", n_targets);
test.AddAttribute("aggregate_function", aggregate_function);
test.AddAttribute("post_transform", post_transform);
test.AddAttribute("tree_roots", tree_roots);
test.AddAttribute("nodes_modes", nodes_modes_as_tensor);
test.AddAttribute("nodes_featureids", nodes_featureids);
test.AddAttribute("nodes_splits", nodes_splits_as_tensor);
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
test.AddAttribute("nodes_trueleafs", nodes_trueleafs);
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
test.AddAttribute("nodes_falseleafs", nodes_falseleafs);
test.AddAttribute("leaf_targetids", leaf_targetids);
test.AddAttribute("leaf_weights", leaf_weights_as_tensor);
// fill input data
std::vector<double> X = {1.f, 4.f};
std::vector<double> Y = {6.23f, 6.23f};
test.AddInput<double>("X", {2, 1}, X);
test.AddOutput<double>("Y", {2, 1}, Y);
test.Run();
}
} // namespace test
} // namespace onnxruntime


@@ -679,6 +679,90 @@ TEST(MLOpTest, TreeRegressorSingleTargetSum_as_tensor_precision) {
GenTreeAndRunTest1_as_tensor_precision(3);
}
TEST(MLOpTest, TreeRegressorCategoricals) {
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);
// tree
int64_t n_targets = 1;
std::vector<int64_t> nodes_featureids = {0, 0, 0, 0, 1, 0, 0};
std::vector<std::string> nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF"};
std::vector<float> nodes_values = {1, 3, 4, 0, 5.5, 0, 0};
std::vector<int64_t> nodes_treeids = {0, 0, 0, 0, 0, 0, 0};
std::vector<int64_t> nodes_nodeids = {0, 1, 2, 3, 4, 5, 6};
std::vector<int64_t> nodes_falsenodeids = {1, 2, 3, 0, 5, 0, 0};
std::vector<int64_t> nodes_truenodeids = {4, 4, 4, 0, 6, 0, 0};
std::string post_transform = "NONE";
std::vector<int64_t> target_ids = {0, 0, 0};
std::vector<int64_t> target_nodeids = {3, 5, 6};
std::vector<int64_t> target_treeids = {0, 0, 0};
std::vector<float> target_weights = {-4.699999809265137, 17.700000762939453, 11.100000381469727};
// add attributes
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
test.AddAttribute("nodes_treeids", nodes_treeids);
test.AddAttribute("nodes_nodeids", nodes_nodeids);
test.AddAttribute("nodes_featureids", nodes_featureids);
test.AddAttribute("nodes_values", nodes_values);
test.AddAttribute("nodes_modes", nodes_modes);
test.AddAttribute("target_treeids", target_treeids);
test.AddAttribute("target_nodeids", target_nodeids);
test.AddAttribute("target_ids", target_ids);
test.AddAttribute("target_weights", target_weights);
test.AddAttribute("n_targets", n_targets);
// fill input data
std::vector<float> X = {3.0f, 6.6f, 1.0f, 5.0f, 5.0f, 5.5f};
std::vector<float> Y = {17.700000762939453, 11.100000381469727, -4.699999809265137};
test.AddInput<float>("X", {3, 2}, X);
test.AddOutput<float>("Y", {3, 1}, Y);
test.Run();
}
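The tree above chains three BRANCH_EQ tests on feature 0 that all share the same true child, so it behaves like a single set-membership test on {1, 3, 4}; this is presumably the pattern the categorical folding path collapses. The standalone sketch below evaluates that reduced form and reproduces the test's expected outputs; it is an illustration of the equivalence, not the kernel's code.

// Standalone sketch (not part of the PR): the chained BRANCH_EQ nodes reduce to a
// membership test on feature 0, followed by the single BRANCH_LEQ node on feature 1.
#include <cstdio>
#include <set>
#include <vector>

int main() {
  const std::set<double> categories = {1.0, 3.0, 4.0};  // values of the BRANCH_EQ chain
  std::vector<std::vector<double>> X = {{3.0, 6.6}, {1.0, 5.0}, {5.0, 5.5}};
  for (const auto& row : X) {
    double y;
    if (categories.count(row[0])) {
      // true side of the chain: the BRANCH_LEQ node on feature 1 with threshold 5.5
      y = (row[1] <= 5.5) ? 11.100000381469727 : 17.700000762939453;
    } else {
      y = -4.699999809265137;  // the leaf reached when no equality matches
    }
    std::printf("%g\n", y);  // 17.7, 11.1, -4.7, matching the test's expected Y
  }
  return 0;
}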
TEST(MLOpTest, TreeRegressorCategoricalsFolding) {
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);
// tree
int64_t n_targets = 1;
std::vector<int64_t> nodes_featureids = {0, 0, 1, 1, 0, 0, 0};
std::vector<std::string> nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "LEAF", "LEAF"};
std::vector<float> nodes_values = {1, 3, 2, 3, 0, 0, 0};
std::vector<int64_t> nodes_treeids = {0, 0, 0, 0, 0, 0, 0};
std::vector<int64_t> nodes_nodeids = {0, 1, 2, 3, 4, 5, 6};
std::vector<int64_t> nodes_falsenodeids = {1, 2, 3, 4, 0, 0, 0};
std::vector<int64_t> nodes_truenodeids = {5, 5, 6, 6, 0, 0, 0};
std::string post_transform = "NONE";
std::vector<int64_t> target_ids = {0, 0, 0};
std::vector<int64_t> target_nodeids = {4, 5, 6};
std::vector<int64_t> target_treeids = {0, 0, 0};
std::vector<float> target_weights = {17.700000762939453, 11.100000381469727, -4.699999809265137};
// add attributes
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
test.AddAttribute("nodes_treeids", nodes_treeids);
test.AddAttribute("nodes_nodeids", nodes_nodeids);
test.AddAttribute("nodes_featureids", nodes_featureids);
test.AddAttribute("nodes_values", nodes_values);
test.AddAttribute("nodes_modes", nodes_modes);
test.AddAttribute("target_treeids", target_treeids);
test.AddAttribute("target_nodeids", target_nodeids);
test.AddAttribute("target_ids", target_ids);
test.AddAttribute("target_weights", target_weights);
test.AddAttribute("n_targets", n_targets);
// fill input data
std::vector<float> X = {1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f};
std::vector<float> Y = {11.100000381469727, 11.100000381469727, -4.699999809265137, 17.700000762939453};
test.AddInput<float>("X", {4, 2}, X);
test.AddOutput<float>("Y", {4, 1}, Y);
test.Run();
}
TEST(MLOpTest, TreeRegressorTrueNodeBeforeNode) {
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);