Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of Set Membership in TreeEnsemble #21222

Closed
wants to merge 11 commits into from
Closed
6 changes: 5 additions & 1 deletion onnxruntime/core/providers/cpu/ml/ml_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ enum NODE_MODE : uint8_t {
BRANCH_GTE = 6,
BRANCH_GT = 8,
BRANCH_EQ = 10,
BRANCH_NEQ = 12
BRANCH_NEQ = 12,
BRANCH_SM = 14
};

static inline NODE_MODE MakeTreeNodeMode(const std::string& input) {
Expand All @@ -49,6 +50,9 @@ static inline NODE_MODE MakeTreeNodeMode(const std::string& input) {
if (input == "BRANCH_EQ") {
return NODE_MODE::BRANCH_EQ;
}
if (input == "BRANCH_SM") {
return NODE_MODE::BRANCH_SM;
}
return NODE_MODE::BRANCH_NEQ;
}

Expand Down
148 changes: 126 additions & 22 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "core/platform/threadpool.h"
#include "tree_ensemble_helper.h"

#include <algorithm>

namespace onnxruntime {
namespace ml {
namespace detail {
Expand Down Expand Up @@ -87,11 +89,17 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
void ComputeAgg(concurrency::ThreadPool* ttp, const Tensor* X, Tensor* Y, Tensor* label, const AGG& agg) const;

private:
bool CheckIfSubtreesAreEqual(const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector<NODE_MODE>& cmodes,
const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<float>& target_class_weights, const std::vector<ThresholdType>& target_class_weights_as_tensor,
const InlinedVector<TreeNodeElementId>& node_tree_ids, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices);
size_t AddNodes(const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping,
int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids);
int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids, const std::vector<float>& target_class_weights,
const std::vector<ThresholdType>& target_class_weights_as_tensor, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices);
};

template <typename InputType, typename ThresholdType, typename OutputType>
Expand Down Expand Up @@ -270,6 +278,16 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
}
}

// Sort targets
InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices;
indices.reserve(target_class_nodeids.size());
for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) {
indices.emplace_back(
std::pair<TreeNodeElementId, uint32_t>(TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i));
}

std::sort(indices.begin(), indices.end());

// Let's construct nodes_ such that the false branch is always the next element in nodes_.
// updated_mapping will translates the old position of each node to the new node position in nodes_.
std::vector<size_t> updated_mapping(nodes_treeids.size(), 0);
Expand All @@ -280,26 +298,13 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
int64_t tree_id = node_tree_ids[i].tree_id;
size_t root_position =
AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values,
nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
target_class_weights, target_class_weights_as_tensor, indices);
roots_.push_back(&nodes_[root_position]);
previous_tree_id = tree_id;
}
}

n_trees_ = roots_.size();
if (((int64_t)nodes_.size()) != n_nodes_) {
ORT_THROW("Number of nodes in nodes_ (", nodes_.size(), ") is different from n_nodes (", n_nodes_, ").");
}

// Sort targets
InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices;
indices.reserve(target_class_nodeids.size());
for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) {
indices.emplace_back(
std::pair<TreeNodeElementId, uint32_t>(TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i));
}

std::sort(indices.begin(), indices.end());

TreeNodeElementId ind;
SparseValue<ThresholdType> w;
Expand Down Expand Up @@ -341,13 +346,59 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
return Status::OK();
}

template <typename InputType, typename ThresholdType, typename OutputType>
bool TreeEnsembleCommon<InputType, ThresholdType, OutputType>::CheckIfSubtreesAreEqual(
const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector<NODE_MODE>& cmodes,
const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<float>& target_class_weights, const std::vector<ThresholdType>& target_class_weights_as_tensor,
const InlinedVector<TreeNodeElementId>& node_tree_ids, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices) {
// Leaves have values set at 0
if (cmodes[left_id] != cmodes[right_id] || nodes_featureids[left_id] != nodes_featureids[right_id] || (!nodes_values_as_tensor.empty() && nodes_values_as_tensor[left_id] != nodes_values_as_tensor[right_id]) || (nodes_values_as_tensor.empty() && node_values[left_id] != node_values[right_id])) {
return false;
}

if (cmodes[left_id] == NODE_MODE::LEAF) {
auto left_tree_node = node_tree_ids[left_id];
auto left_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(left_tree_node, uint32_t(0)))->second;

auto right_tree_node = node_tree_ids[right_id];
auto right_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(right_tree_node, uint32_t(0)))->second;

if (target_class_weights_as_tensor.empty()) {
return target_class_weights[left_target_node] == target_class_weights[right_target_node];
} else {
return target_class_weights_as_tensor[left_target_node] == target_class_weights_as_tensor[right_target_node];
}
}

return CheckIfSubtreesAreEqual(falsenode_ids[left_id], falsenode_ids[right_id], tree_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids,
nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices) &&
CheckIfSubtreesAreEqual(truenode_ids[left_id], truenode_ids[right_id], tree_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids,
nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices);
}

inline void UpdateThreshold(double val, double& mask) {
uint64_t new_mask = *reinterpret_cast<uint64_t*>(&mask) | (1ll << (static_cast<uint32_t>(val) - 1));
mask = *reinterpret_cast<double*>(&new_mask);
}

inline void UpdateThreshold(float val, float& mask) {
uint32_t new_mask = *reinterpret_cast<uint32_t*>(&mask) | (1 << (static_cast<uint32_t>(val) - 1));
mask = *reinterpret_cast<float*>(&new_mask);
}

#define BITCOUNT(T) int64_t(sizeof(T) * 8)
#define CANMASK(v, T) (v >= 1 && v <= BITCOUNT(T))

template <typename InputType, typename ThresholdType, typename OutputType>
bili2002 marked this conversation as resolved.
Show resolved Hide resolved
size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping, int64_t tree_id,
const InlinedVector<TreeNodeElementId>& node_tree_ids) {
const InlinedVector<TreeNodeElementId>& node_tree_ids, const std::vector<float>& target_class_weights,
const std::vector<ThresholdType>& target_class_weights_as_tensor, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices) {
bili2002 marked this conversation as resolved.
Show resolved Hide resolved
// Validate this index maps to the same tree_id as the one we should be building.
if (node_tree_ids[i].tree_id != tree_id) {
ORT_THROW("Tree id mismatch. Expected ", tree_id, " but got ", node_tree_ids[i].tree_id, " at position ", i);
Expand All @@ -369,23 +420,47 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
if (node.feature_id > max_feature_id_) {
max_feature_id_ = node.feature_id;
}
node.value_or_unique_weight =
nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];

node.value_or_unique_weight = 0;
const auto node_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];
if (node.flags == NODE_MODE::BRANCH_EQ && CANMASK(node_threshold, ThresholdType)) {
UpdateThreshold(node_threshold, node.value_or_unique_weight);
node.flags = NODE_MODE::BRANCH_SM;
} else {
node.value_or_unique_weight = node_threshold;
}

if (i < static_cast<size_t>(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) {
node.flags |= static_cast<uint8_t>(MissingTrack::kTrue);
}
nodes_.push_back(std::move(node));
if (nodes_[node_pos].is_not_leaf()) {
auto falsenode_id = falsenode_ids[i];
if (nodes_[node_pos].flags == NODE_MODE::BRANCH_SM) {
auto falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];

while (cmodes[falsenode_id] == NODE_MODE::BRANCH_EQ && nodes_[node_pos].feature_id == nodes_featureids[falsenode_id] &&
CANMASK(falsenode_threshold, ThresholdType) &&
CheckIfSubtreesAreEqual(truenode_ids[i], truenode_ids[falsenode_id], tree_id, cmodes, truenode_ids, falsenode_ids,
nodes_featureids, nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices)) {
UpdateThreshold(falsenode_threshold, nodes_[node_pos].value_or_unique_weight);
falsenode_id = falsenode_ids[falsenode_id];
falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];
}
}

size_t false_branch =
AddNodes(falsenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
AddNodes(falsenode_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
target_class_weights, target_class_weights_as_tensor, indices);
if (false_branch != node_pos + 1) {
ORT_THROW("False node must always be the next node, but it isn't at index ", node_pos, " with flags ",
static_cast<int>(nodes_[node_pos].flags));
}
size_t true_branch =
AddNodes(truenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
target_class_weights, target_class_weights_as_tensor, indices);
// We don't need to store the false branch pointer since we know it is always in the immediate next entry in nodes_.
// nodes_[node_pos].falsenode_inc_or_n_weights.ptr = &nodes_[false_branch];
nodes_[node_pos].truenode_or_weight.ptr = &nodes_[true_branch];
Expand Down Expand Up @@ -684,6 +759,16 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
} \
}

inline bool SetMembershipCheck(double val, double mask) {
auto val_as_int = static_cast<int64_t>(val);
return CANMASK(val_as_int, double) && (((1ll << (val_as_int - 1)) & *reinterpret_cast<uint64_t*>(&mask)) != 0);
}

inline bool SetMembershipCheck(float val, float mask) {
auto val_as_int = static_cast<int64_t>(val);
return CANMASK(val_as_int, float) && (((1ll << (val_as_int - 1)) & *reinterpret_cast<uint32_t*>(&mask)) != 0);
}

inline bool _isnan_(float x) { return std::isnan(x); }
inline bool _isnan_(double x) { return std::isnan(x); }
inline bool _isnan_(int64_t) { return false; }
Expand Down Expand Up @@ -726,6 +811,20 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
case NODE_MODE::BRANCH_NEQ:
TREE_FIND_VALUE(!=)
break;
case NODE_MODE::BRANCH_SM:
if (has_missing_tracks_) {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val)))
? root->truenode_or_weight.ptr
: root + 1;
}
} else {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
root = SetMembershipCheck(val, root->value_or_unique_weight) ? root->truenode_or_weight.ptr : root + 1;
}
}
case NODE_MODE::LEAF:
break;
}
Expand Down Expand Up @@ -759,6 +858,11 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
root = val != threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_SM:
root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val)))
? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::LEAF:
return root;
}
Expand Down
84 changes: 84 additions & 0 deletions onnxruntime/test/providers/cpu/ml/treeregressor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,90 @@ TEST(MLOpTest, TreeRegressorSingleTargetSum_as_tensor_precision) {
GenTreeAndRunTest1_as_tensor_precision(3);
}

TEST(MLOpTest, TreeRegressorCategoricals) {
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);

// tree
int64_t n_targets = 1;
std::vector<int64_t> nodes_featureids = {0, 0, 0, 0, 1, 0, 0};
std::vector<std::string> nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF"};
bili2002 marked this conversation as resolved.
Show resolved Hide resolved
std::vector<float> nodes_values = {1, 3, 4, 0, 5.5, 0, 0};

std::vector<int64_t> nodes_treeids = {0, 0, 0, 0, 0, 0, 0};
std::vector<int64_t> nodes_nodeids = {0, 1, 2, 3, 4, 5, 6};
std::vector<int64_t> nodes_falsenodeids = {1, 2, 3, 0, 5, 0, 0};
std::vector<int64_t> nodes_truenodeids = {4, 4, 4, 0, 6, 0, 0};

std::string post_transform = "NONE";
std::vector<int64_t> target_ids = {0, 0, 0};
std::vector<int64_t> target_nodeids = {3, 5, 6};
std::vector<int64_t> target_treeids = {0, 0, 0};
std::vector<float> target_weights = {-4.699999809265137, 17.700000762939453, 11.100000381469727};

// add attributes
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
test.AddAttribute("nodes_treeids", nodes_treeids);
test.AddAttribute("nodes_nodeids", nodes_nodeids);
test.AddAttribute("nodes_featureids", nodes_featureids);
test.AddAttribute("nodes_values", nodes_values);
test.AddAttribute("nodes_modes", nodes_modes);
test.AddAttribute("target_treeids", target_treeids);
test.AddAttribute("target_nodeids", target_nodeids);
test.AddAttribute("target_ids", target_ids);
test.AddAttribute("target_weights", target_weights);
test.AddAttribute("n_targets", n_targets);

// fill input data
std::vector<float> X = {3.0f, 6.6f, 1.0f, 5.0f, 5.0f, 5.5f};
std::vector<float> Y = {17.700000762939453, 11.100000381469727, -4.699999809265137};
test.AddInput<float>("X", {3, 2}, X);
test.AddOutput<float>("Y", {3, 1}, Y);
test.Run();
}

TEST(MLOpTest, TreeRegressorCategoricalsFolding) {
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);

// tree
int64_t n_targets = 1;
std::vector<int64_t> nodes_featureids = {0, 0, 1, 1, 0, 0, 0};
std::vector<std::string> nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "LEAF", "LEAF"};
std::vector<float> nodes_values = {1, 3, 2, 3, 0, 0, 0};

std::vector<int64_t> nodes_treeids = {0, 0, 0, 0, 0, 0, 0};
std::vector<int64_t> nodes_nodeids = {0, 1, 2, 3, 4, 5, 6};
std::vector<int64_t> nodes_falsenodeids = {1, 2, 3, 4, 0, 0, 0};
std::vector<int64_t> nodes_truenodeids = {5, 5, 6, 6, 0, 0, 0};

std::string post_transform = "NONE";
std::vector<int64_t> target_ids = {0, 0, 0};
std::vector<int64_t> target_nodeids = {4, 5, 6};
std::vector<int64_t> target_treeids = {0, 0, 0};
std::vector<float> target_weights = {17.700000762939453, 11.100000381469727, -4.699999809265137};

// add attributes
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
test.AddAttribute("nodes_treeids", nodes_treeids);
test.AddAttribute("nodes_nodeids", nodes_nodeids);
test.AddAttribute("nodes_featureids", nodes_featureids);
test.AddAttribute("nodes_values", nodes_values);
test.AddAttribute("nodes_modes", nodes_modes);
test.AddAttribute("target_treeids", target_treeids);
test.AddAttribute("target_nodeids", target_nodeids);
test.AddAttribute("target_ids", target_ids);
test.AddAttribute("target_weights", target_weights);
test.AddAttribute("n_targets", n_targets);

// fill input data
std::vector<float> X = {1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f};
std::vector<float> Y = {11.100000381469727, 11.100000381469727, -4.699999809265137, 17.700000762939453};
test.AddInput<float>("X", {4, 2}, X);
test.AddOutput<float>("Y", {4, 1}, Y);
test.Run();
}

TEST(MLOpTest, TreeRegressorTrueNodeBeforeNode) {
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);

Expand Down
Loading