Skip to content

Commit

Permalink
Implementation of TreeEnsemble ai.onnx.ml==5 (#22333)
Browse files Browse the repository at this point in the history
### Description
Merges PR #21851, #21222.

Implements TreeEnsemble from ai.onnx.ml==5 (CPU).

---------

Co-authored-by: Bilyana Indzheva <[email protected]>
Co-authored-by: Bilyana Indzheva <[email protected]>
Co-authored-by: Christian Bourjau <[email protected]>
  • Loading branch information
4 people authored Nov 22, 2024
1 parent c97dd6e commit a2ba3cb
Show file tree
Hide file tree
Showing 13 changed files with 1,155 additions and 349 deletions.
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ Do not modify directly.*
|SVMClassifier|*in* X:**T1**<br> *out* Y:**T2**<br> *out* Z:**tensor(float)**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int64), tensor(string)|
|SVMRegressor|*in* X:**T**<br> *out* Y:**tensor(float)**|1+|**T** = tensor(float)|
|Scaler|*in* X:**T**<br> *out* Y:**tensor(float)**|1+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|TreeEnsemble|*in* X:**T**<br> *out* Y:**T**|5+|**T** = tensor(double), tensor(float)|
|TreeEnsembleClassifier|*in* X:**T1**<br> *out* Y:**T2**<br> *out* Z:**tensor(float)**|3+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int64), tensor(string)|
|||[1, 2]|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int64), tensor(string)|
|TreeEnsembleRegressor|*in* X:**T**<br> *out* Y:**tensor(float)**|3+|**T** = tensor(double), tensor(float)|
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2925,6 +2925,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, int32_t, TreeEnsembleClassifier);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, float, TreeEnsembleRegressor);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double, TreeEnsembleRegressor);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, float, TreeEnsemble);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, double, TreeEnsemble);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, float_string, LabelEncoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, string_float, LabelEncoder);
Expand Down Expand Up @@ -3043,6 +3045,10 @@ Status RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) {
TreeEnsembleRegressor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double,
TreeEnsembleRegressor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, float,
TreeEnsemble)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, double,
TreeEnsemble)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, float_string,
LabelEncoder)>,
Expand Down
58 changes: 31 additions & 27 deletions onnxruntime/core/providers/cpu/ml/ml_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,44 +20,48 @@ enum class OUTPUT_MODE {
ALL_SCORES
};

enum NODE_MODE : uint8_t {
LEAF = 1,
BRANCH_LEQ = 2,
BRANCH_LT = 4,
BRANCH_GTE = 6,
BRANCH_GT = 8,
BRANCH_EQ = 10,
BRANCH_NEQ = 12
enum NODE_MODE_ONNX : uint8_t {
BRANCH_LEQ = 0,
BRANCH_LT = 1,
BRANCH_GTE = 2,
BRANCH_GT = 3,
BRANCH_EQ = 4,
BRANCH_NEQ = 5,
BRANCH_MEMBER = 6,
LEAF = 7,
};

static inline NODE_MODE MakeTreeNodeMode(const std::string& input) {
static inline NODE_MODE_ONNX MakeTreeNodeMode(const std::string& input) {
if (input == "BRANCH_LEQ") {
return NODE_MODE::BRANCH_LEQ;
return NODE_MODE_ONNX::BRANCH_LEQ;
}
if (input == "LEAF") {
return NODE_MODE::LEAF;
return NODE_MODE_ONNX::LEAF;
}
if (input == "BRANCH_LT") {
return NODE_MODE::BRANCH_LT;
return NODE_MODE_ONNX::BRANCH_LT;
}
if (input == "BRANCH_GTE") {
return NODE_MODE::BRANCH_GTE;
return NODE_MODE_ONNX::BRANCH_GTE;
}
if (input == "BRANCH_GT") {
return NODE_MODE::BRANCH_GT;
return NODE_MODE_ONNX::BRANCH_GT;
}
if (input == "BRANCH_EQ") {
return NODE_MODE::BRANCH_EQ;
return NODE_MODE_ONNX::BRANCH_EQ;
}
return NODE_MODE::BRANCH_NEQ;
if (input == "BRANCH_MEMBER") {
return NODE_MODE_ONNX::BRANCH_MEMBER;
}
return NODE_MODE_ONNX::BRANCH_NEQ;
}

enum class POST_EVAL_TRANSFORM {
NONE,
LOGISTIC,
SOFTMAX,
SOFTMAX_ZERO,
PROBIT
enum class POST_EVAL_TRANSFORM : int64_t {
NONE = 0,
LOGISTIC = 1,
SOFTMAX = 2,
SOFTMAX_ZERO = 3,
PROBIT = 4
};

static inline POST_EVAL_TRANSFORM MakeTransform(const std::string& input) {
Expand All @@ -76,11 +80,11 @@ static inline POST_EVAL_TRANSFORM MakeTransform(const std::string& input) {
return POST_EVAL_TRANSFORM::PROBIT;
}

enum class AGGREGATE_FUNCTION {
AVERAGE,
SUM,
MIN,
MAX
enum class AGGREGATE_FUNCTION : int64_t {
AVERAGE = 0,
SUM = 1,
MIN = 2,
MAX = 3
};

static inline AGGREGATE_FUNCTION MakeAggregateFunction(const std::string& input) {
Expand Down
59 changes: 59 additions & 0 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cpu/ml/tree_ensemble.h"
#include "core/providers/cpu/ml/tree_ensemble_helper.h"
#include "core/common/inlined_containers_fwd.h"

namespace onnxruntime {
namespace ml {

ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
TreeEnsemble,
5,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).MayInplace(0, 0),
TreeEnsemble<float>);

ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
TreeEnsemble,
5,
double,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()).MayInplace(0, 0),
TreeEnsemble<double>);

template <typename T>
TreeEnsemble<T>::TreeEnsemble(const OpKernelInfo& info) : OpKernel(info) {
if constexpr (std::is_same<T, double>::value) {
p_tree_ensemble_ = std::make_unique<detail::TreeEnsembleCommonV5<T, double>>();
} else {
p_tree_ensemble_ = std::make_unique<detail::TreeEnsembleCommonV5<T, float>>();
}
ORT_THROW_IF_ERROR(p_tree_ensemble_->Init(info));
}

template <typename T>
Status TreeEnsemble<T>::GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const {
InlinedVector<std::string> names{
"leaf_targetids", "leaf_weights", "membership_values", "nodes_falseleafs",
"nodes_falsenodeids", "nodes_featureids", "nodes_hitrates", "nodes_missing_value_tracks_true",
"nodes_modes", "nodes_splits", "nodes_trueleafs", "nodes_truenodeids"};
removable_attributes.swap(names);
return Status::OK();
}

template <typename T>
common::Status TreeEnsemble<T>::Compute(OpKernelContext* context) const {
const auto* X = context->Input<Tensor>(0);
if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
if (X->Shape().NumDimensions() == 0) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"Input shape needs to be at least a single dimension.");
}
int64_t N = X->Shape().NumDimensions() == 1 ? 1 : X->Shape()[0];
Tensor* Y = context->Output(0, {N, p_tree_ensemble_->get_target_or_class_count()});
return p_tree_ensemble_->compute(context, X, Y, NULL);
}

} // namespace ml
} // namespace onnxruntime
25 changes: 25 additions & 0 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "tree_ensemble_common.h"

namespace onnxruntime {
namespace ml {
template <typename T>
class TreeEnsemble final : public OpKernel {
typedef T InputType; // input type
typedef float OutputType; // output type
public:
explicit TreeEnsemble(const OpKernelInfo& info);
common::Status Compute(OpKernelContext* context) const override;
Status GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const override;

private:
// Pointer on one instance of
// detail::TreeEnsembleCommonV5<T, ThresholdType>
// where ThresholdType is defined after accessing the attributes.
std::unique_ptr<detail::TreeEnsembleCommonAttributes> p_tree_ensemble_;
};
} // namespace ml
} // namespace onnxruntime
40 changes: 37 additions & 3 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,40 @@ union PtrOrWeight {
} weight_data;
};

enum NODE_MODE_ORT : uint8_t {
LEAF = 1,
BRANCH_LEQ = 2,
BRANCH_LT = 4,
BRANCH_GTE = 6,
BRANCH_GT = 8,
BRANCH_EQ = 10,
BRANCH_NEQ = 12,
BRANCH_MEMBER = 14,
};

inline NODE_MODE_ORT Convert_NODE_MODE_ONNX_to_ORT(NODE_MODE_ONNX node_mode) {
switch (node_mode) {
case NODE_MODE_ONNX::LEAF:
return NODE_MODE_ORT::LEAF;
case NODE_MODE_ONNX::BRANCH_LEQ:
return NODE_MODE_ORT::BRANCH_LEQ;
case NODE_MODE_ONNX::BRANCH_LT:
return NODE_MODE_ORT::BRANCH_LT;
case NODE_MODE_ONNX::BRANCH_GTE:
return NODE_MODE_ORT::BRANCH_GTE;
case NODE_MODE_ONNX::BRANCH_GT:
return NODE_MODE_ORT::BRANCH_GT;
case NODE_MODE_ONNX::BRANCH_EQ:
return NODE_MODE_ORT::BRANCH_EQ;
case NODE_MODE_ONNX::BRANCH_NEQ:
return NODE_MODE_ORT::BRANCH_NEQ;
case NODE_MODE_ONNX::BRANCH_MEMBER:
return NODE_MODE_ORT::BRANCH_MEMBER;
default:
ORT_THROW("Unexpected value for node_mode");
};
}

template <typename T>
struct TreeNodeElement {
int feature_id;
Expand All @@ -98,10 +132,10 @@ struct TreeNodeElement {
// weight in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, the weight is also
// stored in `value_or_unique_weight`.
PtrOrWeight<T> truenode_or_weight;
uint8_t flags;
NODE_MODE_ORT flags;

inline NODE_MODE mode() const { return NODE_MODE(flags & 0xF); }
inline bool is_not_leaf() const { return !(flags & NODE_MODE::LEAF); }
inline NODE_MODE_ORT mode() const { return NODE_MODE_ORT(flags & 0xF); }
inline bool is_not_leaf() const { return !(flags & NODE_MODE_ORT::LEAF); }
inline bool is_missing_track_true() const { return flags & MissingTrack::kTrue; }
};

Expand Down
Loading

0 comments on commit a2ba3cb

Please sign in to comment.