Skip to content

Commit

Permalink
fix node_mode
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Nov 19, 2024
1 parent dbfd50d commit 8855e2b
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 52 deletions.
20 changes: 10 additions & 10 deletions onnxruntime/core/providers/cpu/ml/ml_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ enum class OUTPUT_MODE {
ALL_SCORES
};

enum NODE_MODE : uint8_t {
enum NODE_MODE_ONNX : uint8_t {
BRANCH_LEQ = 0,
BRANCH_LT = 1,
BRANCH_GTE = 2,
Expand All @@ -31,29 +31,29 @@ enum NODE_MODE : uint8_t {
LEAF = 7,
};

static inline NODE_MODE MakeTreeNodeMode(const std::string& input) {
static inline NODE_MODE_ONNX MakeTreeNodeMode(const std::string& input) {
if (input == "BRANCH_LEQ") {
return NODE_MODE::BRANCH_LEQ;
return NODE_MODE_ONNX::BRANCH_LEQ;
}
if (input == "LEAF") {
return NODE_MODE::LEAF;
return NODE_MODE_ONNX::LEAF;
}
if (input == "BRANCH_LT") {
return NODE_MODE::BRANCH_LT;
return NODE_MODE_ONNX::BRANCH_LT;
}
if (input == "BRANCH_GTE") {
return NODE_MODE::BRANCH_GTE;
return NODE_MODE_ONNX::BRANCH_GTE;
}
if (input == "BRANCH_GT") {
return NODE_MODE::BRANCH_GT;
return NODE_MODE_ONNX::BRANCH_GT;
}
if (input == "BRANCH_EQ") {
return NODE_MODE::BRANCH_EQ;
return NODE_MODE_ONNX::BRANCH_EQ;
}
if (input == "BRANCH_MEMBER") {
return NODE_MODE::BRANCH_MEMBER;
return NODE_MODE_ONNX::BRANCH_MEMBER;
}
return NODE_MODE::BRANCH_NEQ;
return NODE_MODE_ONNX::BRANCH_NEQ;
}

enum class POST_EVAL_TRANSFORM : int64_t {
Expand Down
40 changes: 37 additions & 3 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,40 @@ union PtrOrWeight {
} weight_data;
};

enum NODE_MODE_ORT : uint8_t {
LEAF = 1,
BRANCH_LEQ = 2,
BRANCH_LT = 4,
BRANCH_GTE = 6,
BRANCH_GT = 8,
BRANCH_EQ = 10,
BRANCH_NEQ = 12,
BRANCH_MEMBER = 14,
};

inline NODE_MODE_ORT Convert_NODE_MODE_ONNX_to_ORT(NODE_MODE_ONNX node_mode) {
switch (node_mode) {
case NODE_MODE_ONNX::LEAF:
return NODE_MODE_ORT::LEAF;
case NODE_MODE_ONNX::BRANCH_LEQ:
return NODE_MODE_ORT::BRANCH_LEQ;
case NODE_MODE_ONNX::BRANCH_LT:
return NODE_MODE_ORT::BRANCH_LT;
case NODE_MODE_ONNX::BRANCH_GTE:
return NODE_MODE_ORT::BRANCH_GTE;
case NODE_MODE_ONNX::BRANCH_GT:
return NODE_MODE_ORT::BRANCH_GT;
case NODE_MODE_ONNX::BRANCH_EQ:
return NODE_MODE_ORT::BRANCH_EQ;
case NODE_MODE_ONNX::BRANCH_NEQ:
return NODE_MODE_ORT::BRANCH_NEQ;
case NODE_MODE_ONNX::BRANCH_MEMBER:
return NODE_MODE_ORT::BRANCH_MEMBER;
default:
ORT_THROW("Unexpected value for node_mode");
};

Check warning on line 112 in onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 You don't need a ; after a } [readability/braces] [4] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h:112: You don't need a ; after a } [readability/braces] [4]
}

template <typename T>
struct TreeNodeElement {
int feature_id;
Expand All @@ -98,10 +132,10 @@ struct TreeNodeElement {
// weight in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, the weight is also
// stored in `value_or_unique_weight`.
PtrOrWeight<T> truenode_or_weight;
uint8_t flags;
NODE_MODE_ORT flags;

inline NODE_MODE mode() const { return NODE_MODE(flags & 0xF); }
inline bool is_not_leaf() const { return !(flags & NODE_MODE::LEAF); }
inline NODE_MODE_ORT mode() const { return NODE_MODE_ORT(flags & 0xF); }
inline bool is_not_leaf() const { return !(flags & NODE_MODE_ORT::LEAF); }
inline bool is_missing_track_true() const { return flags & MissingTrack::kTrue; }
};

Expand Down
23 changes: 14 additions & 9 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ struct TreeEnsembleAttributesV3 {
std::vector<float> nodes_hitrates;
std::vector<ThresholdType> nodes_hitrates_as_tensor;
std::vector<int64_t> nodes_missing_value_tracks_true;
std::vector<NODE_MODE> nodes_modes;
std::vector<NODE_MODE_ONNX> nodes_modes;
std::vector<int64_t> nodes_nodeids;
std::vector<int64_t> nodes_treeids;
std::vector<int64_t> nodes_truenodeids;
Expand All @@ -118,11 +118,16 @@ struct TreeEnsembleAttributesV5 {
TreeEnsembleAttributesV5() {}
TreeEnsembleAttributesV5(const OpKernelInfo& info) {

Check warning on line 119 in onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Single-parameter constructors should be marked explicit. [runtime/explicit] [4] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h:119: Single-parameter constructors should be marked explicit. [runtime/explicit] [4]
#if !defined(ORT_MINIMAL_BUILD)
std::vector<uint8_t> nodes_modes_i;
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "leaf_weights", leaf_weights));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "membership_values", membership_values));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates", nodes_hitrates));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_modes", nodes_modes));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_modes", nodes_modes_i));
ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_splits", nodes_splits));
nodes_modes.reserve(nodes_modes.size());
for (auto i : nodes_modes_i) {
nodes_modes.push_back(static_cast<NODE_MODE_ONNX>(i));
}
#else
// GetVectorAttrsOrDefault is not part of the minimal build.
// As a result, TreeEnsemble v5 cannot be available in this build.
Expand Down Expand Up @@ -162,7 +167,7 @@ struct TreeEnsembleAttributesV5 {
std::vector<int64_t> nodes_featureids;
std::vector<ThresholdType> nodes_hitrates;
std::vector<int64_t> nodes_missing_value_tracks_true;
std::vector<uint8_t> nodes_modes;
std::vector<NODE_MODE_ONNX> nodes_modes;
std::vector<ThresholdType> nodes_splits;
std::vector<int64_t> nodes_trueleafs;
std::vector<int64_t> nodes_truenodeids;
Expand All @@ -180,7 +185,7 @@ struct TreeEnsembleAttributesV5 {
size_t curr_id = 0;
for (const auto node_mode : nodes_modes) {
membership_values_by_id.emplace_back();
if (node_mode != static_cast<int64_t>(NODE_MODE::BRANCH_MEMBER)) {
if (node_mode != NODE_MODE_ONNX::BRANCH_MEMBER) {
continue;
}

Expand Down Expand Up @@ -231,7 +236,7 @@ struct TreeEnsembleAttributesV5 {
output.nodes_treeids.push_back(curr_treeid);

if (is_leaf) {
output.nodes_modes.push_back(NODE_MODE::LEAF);
output.nodes_modes.push_back(NODE_MODE_ONNX::LEAF);
output.target_class_ids.push_back(leaf_targetids[curr_id]);
output.target_class_nodeids.push_back(curr_nodeid);
output.target_class_treeids.push_back(curr_treeid);
Expand Down Expand Up @@ -261,11 +266,11 @@ struct TreeEnsembleAttributesV5 {
}

// unroll `BRANCH_MEMBER` to a chain of `BRANCH_EQ`
if (nodes_modes[curr_id] == static_cast<uint8_t>(NODE_MODE::BRANCH_MEMBER)) {
output.nodes_modes.push_back(NODE_MODE::BRANCH_EQ);
if (nodes_modes[curr_id] == NODE_MODE_ONNX::BRANCH_MEMBER) {
output.nodes_modes.push_back(NODE_MODE_ONNX::BRANCH_EQ);
output.nodes_values_as_tensor.push_back(membership_values_by_id[curr_id][curr_membership_value_id]);
} else {
output.nodes_modes.push_back(static_cast<NODE_MODE>(nodes_modes[curr_id]));
output.nodes_modes.push_back(nodes_modes[curr_id]);
output.nodes_values_as_tensor.push_back(nodes_splits[curr_id]);
}

Expand All @@ -286,7 +291,7 @@ struct TreeEnsembleAttributesV5 {
// so in that case we are only moving the pointer for `membership_values`
//
// otherwise, the `falsenode_id` is pointing to the real falsenode subtree
if (nodes_modes[curr_id] == static_cast<uint8_t>(NODE_MODE::BRANCH_MEMBER) &&
if (nodes_modes[curr_id] == NODE_MODE_ONNX::BRANCH_MEMBER &&
curr_membership_value_id + 1 < membership_values_by_id[curr_id].size()) {
false_nodeid = transformInputOneTree(curr_id, curr_treeid, false_nodeid, curr_membership_value_id + 1, false,
membership_values_by_id, output);
Expand Down
60 changes: 30 additions & 30 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
void ComputeAgg(concurrency::ThreadPool* ttp, const Tensor* X, Tensor* Y, Tensor* label, const AGG& agg) const;

private:
bool CheckIfSubtreesAreEqual(const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector<NODE_MODE>& cmodes,
bool CheckIfSubtreesAreEqual(const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector<NODE_MODE_ONNX>& cmodes,
const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, gsl::span<const int64_t> nodes_featureids,
gsl::span<const ThresholdType> nodes_values_as_tensor, gsl::span<const float> node_values,
gsl::span<const float> target_class_weights, gsl::span<const ThresholdType> target_class_weights_as_tensor,
const InlinedVector<TreeNodeElementId>& node_tree_ids, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices);
size_t AddNodes(const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
size_t AddNodes(const size_t i, const InlinedVector<NODE_MODE_ONNX>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, gsl::span<const int64_t> nodes_featureids,
gsl::span<const ThresholdType> nodes_values_as_tensor, gsl::span<const float> node_values,
gsl::span<const int64_t> nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping,
Expand Down Expand Up @@ -148,13 +148,13 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
// Additional members
size_t limit;
uint32_t i;
InlinedVector<NODE_MODE> cmodes;
InlinedVector<NODE_MODE_ONNX> cmodes;
cmodes.reserve(attributes.nodes_modes.size());
same_mode_ = true;
int fpos = -1;
for (i = 0, limit = attributes.nodes_modes.size(); i < limit; ++i) {
cmodes.push_back(attributes.nodes_modes[i]);
if (cmodes[i] == NODE_MODE::LEAF) continue;
if (cmodes[i] == NODE_MODE_ONNX::LEAF) continue;
if (fpos == -1) {
fpos = static_cast<int>(i);
continue;
Expand Down Expand Up @@ -189,7 +189,7 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(

TreeNodeElementId coor;
for (i = 0; i < limit; ++i) {
if (cmodes[i] == NODE_MODE::LEAF) {
if (cmodes[i] == NODE_MODE_ONNX::LEAF) {
truenode_ids.push_back(0);
falsenode_ids.push_back(0);
} else {
Expand Down Expand Up @@ -293,7 +293,7 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(

template <typename InputType, typename ThresholdType, typename OutputType>
bool TreeEnsembleCommon<InputType, ThresholdType, OutputType>::CheckIfSubtreesAreEqual(
const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector<NODE_MODE>& cmodes,
const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector<NODE_MODE_ONNX>& cmodes,
const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, gsl::span<const int64_t> nodes_featureids,
gsl::span<const ThresholdType> nodes_values_as_tensor, gsl::span<const float> node_values,
gsl::span<const float> target_class_weights, gsl::span<const ThresholdType> target_class_weights_as_tensor,
Expand All @@ -305,7 +305,7 @@ bool TreeEnsembleCommon<InputType, ThresholdType, OutputType>::CheckIfSubtreesAr
return false;
}

if (cmodes[left_id] == NODE_MODE::LEAF) {
if (cmodes[left_id] == NODE_MODE_ONNX::LEAF) {
const auto left_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[left_id], uint32_t(0)))->second;
const auto right_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[right_id], uint32_t(0)))->second;

Expand Down Expand Up @@ -337,7 +337,7 @@ inline void UpdateThreshold(float val, float& mask) {

template <typename InputType, typename ThresholdType, typename OutputType>
size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
const size_t i, const InlinedVector<NODE_MODE_ONNX>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, gsl::span<const int64_t> nodes_featureids,
gsl::span<const ThresholdType> nodes_values_as_tensor, gsl::span<const float> node_values,
gsl::span<const int64_t> nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping, int64_t tree_id,
Expand All @@ -359,23 +359,23 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
updated_mapping[i] = node_pos;

TreeNodeElement<ThresholdType> node;
node.flags = static_cast<uint8_t>(cmodes[i]);
node.flags = Convert_NODE_MODE_ONNX_to_ORT(cmodes[i]);
node.feature_id = static_cast<int>(nodes_featureids[i]);
if (node.feature_id > max_feature_id_) {
max_feature_id_ = node.feature_id;
}

node.value_or_unique_weight = 0;
const ThresholdType node_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];
if (node.flags == NODE_MODE::BRANCH_EQ && CANMASK(node_threshold, ThresholdType)) {
if (node.flags == NODE_MODE_ORT::BRANCH_EQ && CANMASK(node_threshold, ThresholdType)) {
UpdateThreshold(node_threshold, node.value_or_unique_weight);
node.flags = NODE_MODE::BRANCH_MEMBER;
node.flags = NODE_MODE_ORT::BRANCH_MEMBER;
} else {
node.value_or_unique_weight = node_threshold;
}

if (i < static_cast<size_t>(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) {
node.flags |= static_cast<uint8_t>(MissingTrack::kTrue);
node.flags = static_cast<NODE_MODE_ORT>(static_cast<uint8_t>(node.flags) | static_cast<uint8_t>(MissingTrack::kTrue));
}
nodes_.push_back(std::move(node));
if (nodes_[node_pos].is_not_leaf()) {
Expand All @@ -387,10 +387,10 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
// Afterwards, when looking whether a feature is included we can do an `and` with the mask of the node
// and the one of the feature (the mask has only one bit set on the place for its value)
// Beware that if a category is bigger than the threshold type, the node stays as `EQ` and no combination is done
if (nodes_[node_pos].flags == NODE_MODE::BRANCH_MEMBER) {
if (nodes_[node_pos].flags == NODE_MODE_ORT::BRANCH_MEMBER) {
ThresholdType falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];

while (cmodes[falsenode_id] == NODE_MODE::BRANCH_EQ && nodes_[node_pos].feature_id == nodes_featureids[falsenode_id] &&
while (cmodes[falsenode_id] == NODE_MODE_ORT::BRANCH_EQ && nodes_[node_pos].feature_id == nodes_featureids[falsenode_id] &&

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Vcpkg

comparison of different enumeration types ('ValueType<allocator<NODE_MODE_ONNX>>' (aka 'onnxruntime::ml::NODE_MODE_ONNX') and 'onnxruntime::ml::detail::NODE_MODE_ORT') is deprecated [-Wdeprecated-enum-compare]
CANMASK(falsenode_threshold, ThresholdType) &&
CheckIfSubtreesAreEqual(truenode_ids[i], truenode_ids[falsenode_id], tree_id, cmodes, truenode_ids, falsenode_ids,
nodes_featureids, nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices)) {
Expand Down Expand Up @@ -724,7 +724,7 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
InputType val;
if (same_mode_) {
switch (root->mode()) {
case NODE_MODE::BRANCH_LEQ:
case NODE_MODE_ORT::BRANCH_LEQ:
if (has_missing_tracks_) {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
Expand All @@ -739,22 +739,22 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
}
}
break;
case NODE_MODE::BRANCH_LT:
case NODE_MODE_ORT::BRANCH_LT:
TREE_FIND_VALUE(<)
break;
case NODE_MODE::BRANCH_GTE:
case NODE_MODE_ORT::BRANCH_GTE:
TREE_FIND_VALUE(>=)
break;
case NODE_MODE::BRANCH_GT:
case NODE_MODE_ORT::BRANCH_GT:
TREE_FIND_VALUE(>)
break;
case NODE_MODE::BRANCH_EQ:
case NODE_MODE_ORT::BRANCH_EQ:
TREE_FIND_VALUE(==)
break;
case NODE_MODE::BRANCH_NEQ:
case NODE_MODE_ORT::BRANCH_NEQ:
TREE_FIND_VALUE(!=)
break;
case NODE_MODE::BRANCH_MEMBER:
case NODE_MODE_ORT::BRANCH_MEMBER:
if (has_missing_tracks_) {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
Expand All @@ -768,7 +768,7 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
root = SetMembershipCheck(val, root->value_or_unique_weight) ? root->truenode_or_weight.ptr : root + 1;
}
}
case NODE_MODE::LEAF:
case NODE_MODE_ORT::LEAF:

Check warning

Code scanning / PREfast

Unannotated fallthrough between switch labels (es.78). Warning

Unannotated fallthrough between switch labels (es.78).

Check warning

Code scanning / PREfast

Unannotated fallthrough between switch labels (es.78). Warning

Unannotated fallthrough between switch labels (es.78).

Check warning

Code scanning / PREfast

Unannotated fallthrough between switch labels (es.78). Warning

Unannotated fallthrough between switch labels (es.78).

Check warning

Code scanning / PREfast

Unannotated fallthrough between switch labels (es.78). Warning

Unannotated fallthrough between switch labels (es.78).

Check warning

Code scanning / PREfast

Unannotated fallthrough between switch labels (es.78). Warning

Unannotated fallthrough between switch labels (es.78).
break;
}
} else { // Different rules to compare to node thresholds.
Expand All @@ -777,36 +777,36 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
val = x_data[root->feature_id];
threshold = root->value_or_unique_weight;
switch (root->mode()) {
case NODE_MODE::BRANCH_LEQ:
case NODE_MODE_ORT::BRANCH_LEQ:
root = val <= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_LT:
case NODE_MODE_ORT::BRANCH_LT:
root = val < threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_GTE:
case NODE_MODE_ORT::BRANCH_GTE:
root = val >= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_GT:
case NODE_MODE_ORT::BRANCH_GT:
root = val > threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_EQ:
case NODE_MODE_ORT::BRANCH_EQ:
root = val == threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_NEQ:
case NODE_MODE_ORT::BRANCH_NEQ:
root = val != threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_MEMBER:
case NODE_MODE_ORT::BRANCH_MEMBER:
root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val)))
? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::LEAF:
case NODE_MODE_ORT::LEAF:
return root;
}
}
Expand Down

0 comments on commit 8855e2b

Please sign in to comment.