Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement the NDCG metric. #8906

Merged
merged 2 commits into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions include/xgboost/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,26 @@ class DMatrixCache {
}
return container_.at(key).value;
}
/**
* \brief Re-initialize the item in cache.
*
* Since the shared_ptr is used to hold the item, any reference that lives outside of
* the cache can no-longer be reached from the cache.
*
* We use reset instead of erase to avoid walking through the whole cache for renewing
* a single item. (the cache is FIFO, needs to maintain the order).
*/
template <typename... Args>
std::shared_ptr<CacheT> ResetItem(std::shared_ptr<DMatrix> m, Args const&... args) {
std::lock_guard<std::mutex> guard{lock_};
CheckConsistent();
auto key = Key{m.get(), std::this_thread::get_id()};
auto it = container_.find(key);
CHECK(it != container_.cend());
it->second = {m, std::make_shared<CacheT>(args...)};
CheckConsistent();
return it->second.value;
}
/**
* \brief Get a const reference to the underlying hash map. Clear expired caches before
* returning.
Expand Down
240 changes: 186 additions & 54 deletions src/metric/rank_metric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,51 @@
// corresponding headers that brings in those function declaration can't be included with CUDA).
// This precludes the CPU and GPU logic to coexist inside a .cu file

#include <dmlc/registry.h>
#include <xgboost/metric.h>
#include "rank_metric.h"

#include <cmath>
#include <vector>
#include <dmlc/omp.h>
#include <dmlc/registry.h>

#include "../collective/communicator-inl.h"
#include "../common/algorithm.h" // Sort
#include "../common/math.h"
#include "../common/ranking_utils.h" // MakeMetricName
#include "../common/threading_utils.h"
#include "metric_common.h"
#include "xgboost/host_device_vector.h"
#include <algorithm> // for stable_sort, copy, fill_n, min, max
#include <array> // for array
#include <cmath> // for log, sqrt
#include <cstddef> // for size_t, std
#include <cstdint> // for uint32_t
#include <functional> // for less, greater
#include <map> // for operator!=, _Rb_tree_const_iterator
#include <memory> // for allocator, unique_ptr, shared_ptr, __shared_...
#include <numeric> // for accumulate
#include <ostream> // for operator<<, basic_ostream, ostringstream
#include <string> // for char_traits, operator<, basic_string, to_string
#include <utility> // for pair, make_pair
#include <vector> // for vector

#include "../collective/communicator-inl.h" // for IsDistributed, Allreduce
#include "../collective/communicator.h" // for Operation
#include "../common/algorithm.h" // for ArgSort, Sort
#include "../common/linalg_op.h" // for cbegin, cend
#include "../common/math.h" // for CmpFirst
#include "../common/optional_weight.h" // for OptionalWeights, MakeOptionalWeights
#include "../common/ranking_utils.h" // for LambdaRankParam, NDCGCache, ParseMetricName
#include "../common/threading_utils.h" // for ParallelFor
#include "../common/transform_iterator.h" // for IndexTransformIter
#include "dmlc/common.h" // for OMPException
#include "metric_common.h" // for MetricNoCache, GPUMetric, PackedReduceResult
#include "xgboost/base.h" // for bst_float, bst_omp_uint, bst_group_t, Args
#include "xgboost/cache.h" // for DMatrixCache
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for MetaInfo, DMatrix
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/json.h" // for Json, FromJson, IsA, ToJson, get, Null, Object
#include "xgboost/linalg.h" // for Tensor, TensorView, Range, VectorView, MakeT...
#include "xgboost/logging.h" // for CHECK, ConsoleLogger, LOG_INFO, CHECK_EQ
#include "xgboost/metric.h" // for MetricReg, XGBOOST_REGISTER_METRIC, Metric
#include "xgboost/span.h" // for Span, operator!=
#include "xgboost/string_view.h" // for StringView

namespace {

using PredIndPair = std::pair<xgboost::bst_float, uint32_t>;
using PredIndPair = std::pair<xgboost::bst_float, xgboost::ltr::rel_degree_t>;
using PredIndPairContainer = std::vector<PredIndPair>;

/*
Expand Down Expand Up @@ -87,8 +115,7 @@ class PerGroupWeightPolicy {

} // anonymous namespace

namespace xgboost {
namespace metric {
namespace xgboost::metric {
// tag the this file, used by force static link later.
DMLC_REGISTRY_FILE_TAG(rank_metric);

Expand Down Expand Up @@ -257,40 +284,6 @@ struct EvalPrecision : public EvalRank {
}
};

/*! \brief NDCG: Normalized Discounted Cumulative Gain at N */
struct EvalNDCG : public EvalRank {
private:
double CalcDCG(const PredIndPairContainer &rec) const {
double sumdcg = 0.0;
for (size_t i = 0; i < rec.size() && i < this->topn; ++i) {
const unsigned rel = rec[i].second;
if (rel != 0) {
sumdcg += ((1 << rel) - 1) / std::log2(i + 2.0);
}
}
return sumdcg;
}

public:
explicit EvalNDCG(const char* name, const char* param) : EvalRank(name, param) {}

double EvalGroup(PredIndPairContainer *recptr) const override {
PredIndPairContainer &rec(*recptr);
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
double dcg = CalcDCG(rec);
std::stable_sort(rec.begin(), rec.end(), common::CmpSecond);
double idcg = CalcDCG(rec);
if (idcg == 0.0f) {
if (this->minus) {
return 0.0f;
} else {
return 1.0f;
}
}
return dcg/idcg;
}
};

/*! \brief Mean Average Precision at N, for both classification and rank */
struct EvalMAP : public EvalRank {
public:
Expand Down Expand Up @@ -377,16 +370,155 @@ XGBOOST_REGISTER_METRIC(Precision, "pre")
.describe("precision@k for rank.")
.set_body([](const char* param) { return new EvalPrecision("pre", param); });

XGBOOST_REGISTER_METRIC(NDCG, "ndcg")
.describe("ndcg@k for rank.")
.set_body([](const char* param) { return new EvalNDCG("ndcg", param); });

XGBOOST_REGISTER_METRIC(MAP, "map")
.describe("map@k for rank.")
.set_body([](const char* param) { return new EvalMAP("map", param); });

XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
.describe("Negative log partial likelihood of Cox proportional hazards model.")
.set_body([](const char*) { return new EvalCox(); });
} // namespace metric
} // namespace xgboost

// ranking metrics that requires cache
template <typename Cache>
class EvalRankWithCache : public Metric {
protected:
ltr::LambdaRankParam param_;
bool minus_{false};
std::string name_;

DMatrixCache<Cache> cache_{DMatrixCache<Cache>::DefaultSize()};

public:
EvalRankWithCache(StringView name, const char* param) {
auto constexpr kMax = ltr::LambdaRankParam::NotSet();
std::uint32_t topn{kMax};
this->name_ = ltr::ParseMetricName(name, param, &topn, &minus_);
if (topn != kMax) {
param_.UpdateAllowUnknown(Args{{"lambdarank_num_pair_per_sample", std::to_string(topn)},
{"lambdarank_pair_method", "topk"}});
}
param_.UpdateAllowUnknown(Args{});
}
void Configure(Args const&) override {
// do not configure, otherwise the ndcg param will be forced into the same as the one in
// objective.
}
void LoadConfig(Json const& in) override {
if (IsA<Null>(in)) {
return;
}
auto const& obj = get<Object const>(in);
auto it = obj.find("lambdarank_param");
if (it != obj.cend()) {
FromJson(it->second, &param_);
}
}

void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String{this->Name()};
out["lambdarank_param"] = ToJson(param_);
}

double Evaluate(HostDeviceVector<float> const& preds, std::shared_ptr<DMatrix> p_fmat) override {
auto const& info = p_fmat->Info();
auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
if (p_cache->Param() != param_) {
p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
}
CHECK(p_cache->Param() == param_);
CHECK_EQ(preds.Size(), info.labels.Size());

return this->Eval(preds, info, p_cache);
}

virtual double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,
std::shared_ptr<Cache> p_cache) = 0;
};

namespace {
double Finalize(double score, double sw) {
std::array<double, 2> dat{score, sw};
collective::Allreduce<collective::Operation::kSum>(dat.data(), dat.size());
if (sw > 0.0) {
score = score / sw;
}

CHECK_LE(score, 1.0 + kRtEps)
<< "Invalid output score, might be caused by invalid query group weight.";
score = std::min(1.0, score);

return score;
}
} // namespace

/**
* \brief Implement the NDCG score function for learning to rank.
*
* Ties are ignored, which can lead to different result with other implementations.
*/
class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
public:
using EvalRankWithCache::EvalRankWithCache;
const char* Name() const override { return name_.c_str(); }

double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,
std::shared_ptr<ltr::NDCGCache> p_cache) override {
if (ctx_->IsCUDA()) {
auto ndcg = cuda_impl::NDCGScore(ctx_, info, preds, minus_, p_cache);
return Finalize(ndcg.Residue(), ndcg.Weights());
}

// group local ndcg
auto group_ptr = p_cache->DataGroupPtr(ctx_);
bst_group_t n_groups = group_ptr.size() - 1;
auto ndcg_gloc = p_cache->Dcg(ctx_);
std::fill_n(ndcg_gloc.Values().data(), ndcg_gloc.Size(), 0.0);

auto h_inv_idcg = p_cache->InvIDCG(ctx_);
auto p_discount = p_cache->Discount(ctx_).data();

auto h_label = info.labels.HostView();
auto h_predt = linalg::MakeTensorView(ctx_, &preds, preds.Size());
auto weights = common::MakeOptionalWeights(ctx_, info.weights_);

common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) {
auto g_predt = h_predt.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1]));
auto g_labels = h_label.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1]), 0);
auto sorted_idx = common::ArgSort<std::size_t>(ctx_, linalg::cbegin(g_predt),
linalg::cend(g_predt), std::greater<>{});
double ndcg{.0};
double inv_idcg = h_inv_idcg(g);
if (inv_idcg <= 0.0) {
ndcg_gloc(g) = minus_ ? 0.0 : 1.0;
return;
}
std::size_t n{std::min(sorted_idx.size(), static_cast<std::size_t>(param_.TopK()))};
if (param_.ndcg_exp_gain) {
for (std::size_t i = 0; i < n; ++i) {
ndcg += p_discount[i] * ltr::CalcDCGGain(g_labels(sorted_idx[i])) * inv_idcg;
}
} else {
for (std::size_t i = 0; i < n; ++i) {
ndcg += p_discount[i] * g_labels(sorted_idx[i]) * inv_idcg;
}
}
ndcg_gloc(g) += ndcg * weights[g];
});
double sum_w{0};
if (weights.Empty()) {
sum_w = n_groups;
} else {
sum_w = std::accumulate(weights.weights.cbegin(), weights.weights.cend(), 0.0);
}
auto ndcg = std::accumulate(linalg::cbegin(ndcg_gloc), linalg::cend(ndcg_gloc), 0.0);
return Finalize(ndcg, sum_w);
}
};

XGBOOST_REGISTER_METRIC(EvalNDCG, "ndcg")
.describe("ndcg@k for ranking.")
.set_body([](char const* param) {
return new EvalNDCG{"ndcg", param};
});
} // namespace xgboost::metric
Loading