Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[multi] Implement weight feature importance. #10700

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions src/gbm/gbtree.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,12 +236,11 @@ class GBTree : public GradientBooster {
auto add_score = [&](auto fn) {
for (auto idx : trees) {
CHECK_LE(idx, total_n_trees) << "Invalid tree index.";
auto const& p_tree = model_.trees[idx];
p_tree->WalkTree([&](bst_node_t nidx) {
auto const& node = (*p_tree)[nidx];
if (!node.IsLeaf()) {
split_counts[node.SplitIndex()]++;
fn(p_tree, nidx, node.SplitIndex());
auto const& tree = *model_.trees[idx];
tree.WalkTree([&](bst_node_t nidx) {
if (!tree.IsLeaf(nidx)) {
split_counts[tree.SplitIndex(nidx)]++;
fn(tree, nidx, tree.SplitIndex(nidx));
}
return true;
});
Expand All @@ -253,12 +252,18 @@ class GBTree : public GradientBooster {
gain_map[split] = split_counts[split];
});
} else if (importance_type == "gain" || importance_type == "total_gain") {
add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
gain_map[split] += p_tree->Stat(nidx).loss_chg;
if (!model_.trees.empty() && model_.trees.front()->IsMultiTarget()) {
LOG(FATAL) << "gain/total_gain " << MTNotImplemented();
}
add_score([&](auto const& tree, bst_node_t nidx, bst_feature_t split) {
gain_map[split] += tree.Stat(nidx).loss_chg;
});
} else if (importance_type == "cover" || importance_type == "total_cover") {
add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
gain_map[split] += p_tree->Stat(nidx).sum_hess;
if (!model_.trees.empty() && model_.trees.front()->IsMultiTarget()) {
LOG(FATAL) << "cover/total_cover " << MTNotImplemented();
}
add_score([&](auto const& tree, bst_node_t nidx, bst_feature_t split) {
gain_map[split] += tree.Stat(nidx).sum_hess;
});
} else {
LOG(FATAL)
Expand Down
30 changes: 30 additions & 0 deletions tests/python/test_with_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,36 @@ def test_feature_importances_weight():
cls.feature_importances_


def test_feature_importances_weight_vector_leaf() -> None:
from sklearn.datasets import make_multilabel_classification

X, y = make_multilabel_classification(random_state=1994)
with pytest.raises(ValueError, match="gain/total_gain"):
clf = xgb.XGBClassifier(multi_strategy="multi_output_tree")
clf.fit(X, y)
clf.feature_importances_

with pytest.raises(ValueError, match="cover/total_cover"):
clf = xgb.XGBClassifier(
multi_strategy="multi_output_tree", importance_type="cover"
)
clf.fit(X, y)
clf.feature_importances_

clf = xgb.XGBClassifier(
multi_strategy="multi_output_tree",
importance_type="weight",
colsample_bynode=0.2,
)
clf.fit(X, y, feature_weights=np.arange(0, X.shape[1]))
fi = clf.feature_importances_
assert fi[0] == 0.0
assert fi[-1] > fi[1] * 5

w = np.polynomial.Polynomial.fit(np.arange(0, X.shape[1]), fi, deg=1)
assert w.coef[1] > 0.03


@pytest.mark.skipif(**tm.no_pandas())
def test_feature_importances_gain():
from sklearn.datasets import load_digits
Expand Down
Loading