Support multi-target, fit intercept for hinge. (#9850)
trivialfis authored Dec 7, 2023
1 parent 39c637e commit 42de920
Showing 8 changed files with 218 additions and 152 deletions.
54 changes: 36 additions & 18 deletions src/common/linalg_op.cuh
@@ -1,31 +1,48 @@
/*!
* Copyright 2021-2022 by XGBoost Contributors
/**
* Copyright 2021-2023, XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_LINALG_OP_CUH_
#define XGBOOST_COMMON_LINALG_OP_CUH_

#include "device_helpers.cuh"
#include <cstdint> // for int32_t
#include <cstdlib> // for size_t
#include <tuple> // for apply

#include "device_helpers.cuh" // for LaunchN
#include "linalg_op.h"
#include "xgboost/context.h"
#include "xgboost/linalg.h"
#include "xgboost/context.h" // for Context
#include "xgboost/linalg.h" // for TensorView

namespace xgboost {
namespace linalg {
template <typename T, int32_t D, typename Fn>
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
"For function with return, use transform instead.");
if (t.Contiguous()) {
auto ptr = t.Values().data();
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable { fn(i, ptr[i]); });
} else {
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable {
T& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
fn(i, v);
namespace cuda_impl {
// Use template specialization to dispatch, Windows + CUDA 11.8 doesn't support extended
// lambda inside constexpr if
template <typename T, std::int32_t D>
struct ElementWiseImpl {
template <typename Fn>
void operator()(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s) {
static_assert(D > 1);
dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) mutable {
std::apply(fn, linalg::UnravelIndex(i, t.Shape()));
});
}
};

template <typename T>
struct ElementWiseImpl<T, 1> {
template <typename Fn>
void operator()(linalg::TensorView<T, 1> t, Fn&& fn, cudaStream_t s) {
dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) { fn(i); });
}
};

template <typename T, std::int32_t D, typename Fn>
void ElementWiseKernel(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
cuda_impl::ElementWiseImpl<T, D>{}(t, fn, s);
}
} // namespace cuda_impl

template <typename T, int32_t D, typename Fn>
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
@@ -42,7 +59,8 @@ void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_

template <typename T, int32_t D, typename Fn>
void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
ctx->IsCUDA() ? ElementWiseKernelDevice(t, fn) : ElementWiseKernelHost(t, ctx->Threads(), fn);
ctx->IsCUDA() ? cuda_impl::ElementWiseKernel(t, fn)
: ElementWiseKernelHost(t, ctx->Threads(), fn);
}
} // namespace linalg
} // namespace xgboost
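Why the extra struct: the in-code comment notes that Windows + CUDA 11.8 rejects an extended `__device__` lambda placed inside an `if constexpr` branch, so the device kernel dispatches on tensor rank through partial specialization instead. Below is a minimal host-only sketch of the same dispatch pattern; the `LaunchN` here is a hypothetical sequential stand-in for `dh::LaunchN`, and the shapes are made up for illustration.

```cpp
#include <cstddef>   // for size_t
#include <cstdint>   // for int32_t
#include <iostream>

// Stand-in for dh::LaunchN: applies fn to every index in [0, n). Hypothetical helper,
// kept sequential so the sketch compiles and runs without CUDA.
template <typename Fn>
void LaunchN(std::size_t n, Fn&& fn) {
  for (std::size_t i = 0; i < n; ++i) fn(i);
}

// Rank-based dispatch through partial specialization, mirroring cuda_impl::ElementWiseImpl.
// The generic template handles D > 1 by unraveling the flat index into (row, column);
// the D == 1 specialization forwards the flat index directly. No branch on D is needed
// inside the function that launches the lambda.
template <std::int32_t D>
struct ElementWiseImpl {
  template <typename Fn>
  void operator()(std::size_t n_rows, std::size_t n_cols, Fn&& fn) {
    LaunchN(n_rows * n_cols, [&](std::size_t k) { fn(k / n_cols, k % n_cols); });
  }
};

template <>
struct ElementWiseImpl<1> {
  template <typename Fn>
  void operator()(std::size_t n, std::size_t /*unused*/, Fn&& fn) {
    LaunchN(n, [&](std::size_t i) { fn(i); });
  }
};

int main() {
  ElementWiseImpl<1>{}(3, 0, [](std::size_t i) { std::cout << "v[" << i << "]\n"; });
  ElementWiseImpl<2>{}(2, 3, [](std::size_t i, std::size_t j) {
    std::cout << "m(" << i << "," << j << ")\n";
  });
}
```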
30 changes: 18 additions & 12 deletions src/common/linalg_op.h
@@ -1,5 +1,5 @@
/*!
* Copyright 2021-2022 by XGBoost Contributors
/**
* Copyright 2021-2023, XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_LINALG_OP_H_
#define XGBOOST_COMMON_LINALG_OP_H_
@@ -27,17 +27,23 @@ void ElementWiseTransformHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&
}
}

template <typename T, int32_t D, typename Fn>
void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& fn) {
static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
"For function with return, use transform instead.");
if (t.Contiguous()) {
auto ptr = t.Values().data();
common::ParallelFor(t.Size(), n_threads, [&](size_t i) { fn(i, ptr[i]); });
template <typename T, std::int32_t D, typename Fn>
void ElementWiseKernelHost(linalg::TensorView<T, D> t, std::int32_t n_threads, Fn &&fn) {
if constexpr (D == 1) {
common::ParallelFor(t.Size(), n_threads, [&](std::size_t i) { fn(i); });
} else if (D == 2 && t.CContiguous() && t.Shape(0) > t.Shape(1) * 64) {
// Heuristic. Tall, c-contiguous matrix,
auto n_rows = t.Shape(0);
auto n_columns = t.Shape(1);
common::ParallelFor(n_rows, n_threads, [&](std::size_t i) {
for (std::size_t j = 0; j < n_columns; ++j) {
fn(i, j);
}
});
} else {
common::ParallelFor(t.Size(), n_threads, [&](size_t i) {
auto& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
fn(i, v);
common::ParallelFor(t.Size(), n_threads, [&](std::size_t i) {
auto idx = linalg::UnravelIndex(i, t.Shape());
std::apply(fn, idx);
});
}
}
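The new 2-D fast path is a cache-friendliness heuristic: when a C-contiguous matrix has far more rows than columns (`Shape(0) > Shape(1) * 64`), each parallel task handles one whole row sequentially rather than one element, so a task stays on a contiguous stripe of memory and the per-element index unravel is skipped. A rough standalone sketch of the two traversal orders follows; `ParallelFor` is a hypothetical sequential stand-in for `common::ParallelFor`, and the matrix sizes are invented.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for common::ParallelFor; sequential here for clarity.
template <typename Fn>
void ParallelFor(std::size_t n, int /*n_threads*/, Fn&& fn) {
  for (std::size_t i = 0; i < n; ++i) fn(i);
}

// Tall, C-contiguous branch: one task per row, columns walked in order, so each task
// touches data[i * n_cols .. i * n_cols + n_cols) contiguously.
void TallMatrixPath(std::vector<float>& data, std::size_t n_rows, std::size_t n_cols) {
  ParallelFor(n_rows, 8, [&](std::size_t i) {
    for (std::size_t j = 0; j < n_cols; ++j) {
      data[i * n_cols + j] += 1.0f;  // plays the role of fn(i, j)
    }
  });
}

// Generic branch: one task per element, flat index unraveled into (i, j).
void GenericPath(std::vector<float>& data, std::size_t n_rows, std::size_t n_cols) {
  ParallelFor(n_rows * n_cols, 8, [&](std::size_t k) {
    std::size_t i = k / n_cols, j = k % n_cols;
    data[i * n_cols + j] += 1.0f;    // plays the role of fn(i, j)
  });
}

int main() {
  std::vector<float> m(1000 * 4, 0.f);  // 1000 rows x 4 columns: "tall" by the 64x rule
  TallMatrixPath(m, 1000, 4);
  GenericPath(m, 1000, 4);
}
```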
120 changes: 66 additions & 54 deletions src/objective/hinge.cu
@@ -4,92 +4,104 @@
* \brief Provides an implementation of the hinge loss function
* \author Henry Gouk
*/
#include "xgboost/objective.h"
#include "xgboost/json.h"
#include "xgboost/span.h"
#include "xgboost/host_device_vector.h"
#include <algorithm> // for max
#include <cstddef> // for size_t
#include <cstdint> // for int32_t

#include "../common/math.h"
#include "../common/transform.h"
#include "../common/common.h"
#include "../common/common.h" // for Range
#if defined(XGBOOST_USE_CUDA)
#include "../common/linalg_op.cuh"
#endif
#include "../common/linalg_op.h"
#include "../common/optional_weight.h" // for OptionalWeights
#include "../common/transform.h" // for Transform
#include "init_estimation.h" // for FitIntercept
#include "xgboost/data.h" // for MetaInfo
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/json.h" // for Json
#include "xgboost/linalg.h" // for UnravelIndex
#include "xgboost/span.h" // for Span

namespace xgboost::obj {

#if defined(XGBOOST_USE_CUDA)
DMLC_REGISTRY_FILE_TAG(hinge_obj_gpu);
#endif // defined(XGBOOST_USE_CUDA)

class HingeObj : public ObjFunction {
class HingeObj : public FitIntercept {
public:
HingeObj() = default;

void Configure(Args const&) override {}
void Configure(Args const &) override {}
ObjInfo Task() const override { return ObjInfo::kRegression; }

void GetGradient(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
[[nodiscard]] bst_target_t Targets(MetaInfo const &info) const override {
// Multi-target regression.
return std::max(static_cast<std::size_t>(1), info.labels.Shape(1));
}

void GetGradient(HostDeviceVector<float> const &preds, MetaInfo const &info,
std::int32_t /*iter*/, linalg::Matrix<GradientPair> *out_gpair) override {
CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels.Size())
<< "labels are not correctly provided"
<< "preds.size=" << preds.Size()
<< ", label.size=" << info.labels.Size();

const size_t ndata = preds.Size();
const bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
CheckInitInputs(info);
CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
if (!info.weights_.Empty()) {
CHECK_EQ(info.weights_.Size(), info.num_row_)
<< "Number of weights should be equal to number of data points.";
}
CHECK_EQ(info.labels.Shape(1), 1) << "Multi-target for `binary:hinge` is not yet supported.";
out_gpair->Reshape(ndata, 1);
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
common::Span<const bst_float> _labels,
common::Span<const bst_float> _weights) {
bst_float p = _preds[_idx];
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
bst_float y = _labels[_idx] * 2.0 - 1.0;
bst_float g, h;
if (p * y < 1.0) {
g = -y * w;
h = w;
} else {
g = 0.0;
h = std::numeric_limits<bst_float>::min();
}
_out_gpair[_idx] = GradientPair(g, h);
},
common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(),
ctx_->Device()).Eval(
out_gpair->Data(), &preds, info.labels.Data(), &info.weights_);

bst_target_t n_targets = this->Targets(info);
out_gpair->Reshape(info.num_row_, n_targets);
auto gpair = out_gpair->View(ctx_->Device());

preds.SetDevice(ctx_->Device());
auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, n_targets);

auto labels = info.labels.View(ctx_->Device());

info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
: info.weights_.ConstHostSpan()};

linalg::ElementWiseKernel(this->ctx_, labels,
[=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
auto w = weight[i];

auto p = predt(i, j);
auto y = labels(i, j) * 2.0 - 1.0;

float g, h;
if (p * y < 1.0) {
g = -y * w;
h = w;
} else {
g = 0.0;
h = std::numeric_limits<float>::min();
}
gpair(i, j) = GradientPair{g, h};
});
}

void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
void PredTransform(HostDeviceVector<float> *io_preds) const override {
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
[] XGBOOST_DEVICE(std::size_t _idx, common::Span<float> _preds) {
_preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
},
common::Range{0, static_cast<int64_t>(io_preds->Size()), 1}, this->ctx_->Threads(),
io_preds->Device())
.Eval(io_preds);
}

[[nodiscard]] const char* DefaultEvalMetric() const override {
return "error";
}
[[nodiscard]] const char *DefaultEvalMetric() const override { return "error"; }

void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
void SaveConfig(Json *p_out) const override {
auto &out = *p_out;
out["name"] = String("binary:hinge");
}
void LoadConfig(Json const &) override {}
};

// register the objective functions
XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge")
.describe("Hinge loss. Expects labels to be in [0,1f]")
.set_body([]() { return new HingeObj(); });
.describe("Hinge loss. Expects labels to be in [0,1f]")
.set_body([]() { return new HingeObj(); });

} // namespace xgboost::obj
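For reference, the gradient pair the rewritten `GetGradient` kernel computes per element, with raw label ℓ ∈ {0, 1} mapped to y = 2ℓ − 1, prediction p for one target column, and sample weight w; this restates the kernel body above, with ε standing for `std::numeric_limits<float>::min()`, which keeps the hessian strictly positive off the margin:

```latex
L(p, \ell) = \max\bigl(0,\; 1 - y\,p\bigr), \qquad y = 2\ell - 1 \in \{-1, +1\},
\qquad
(g, h) =
\begin{cases}
  (-y\,w,\; w)       & \text{if } y\,p < 1,\\
  (0,\; \varepsilon) & \text{otherwise.}
\end{cases}
```

The multi-target support comes from applying this formula independently to each of the `Targets(info)` columns, which is what the reshape of `out_gpair` to `(num_row, n_targets)` and the `(i, j)` kernel signature enable.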
33 changes: 15 additions & 18 deletions src/objective/quantile_obj.cu
@@ -75,28 +75,25 @@ class QuantileRegression : public ObjFunction {
: info.weights_.ConstHostSpan()};

preds.SetDevice(ctx_->Device());
auto predt = linalg::MakeVec(&preds);
auto n_samples = info.num_row_;
auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, n_targets);

alpha_.SetDevice(ctx_->Device());
auto alpha = ctx_->IsCUDA() ? alpha_.ConstDeviceSpan() : alpha_.ConstHostSpan();

linalg::ElementWiseKernel(
ctx_, gpair, [=] XGBOOST_DEVICE(std::size_t i, GradientPair const&) mutable {
auto [sample_id, quantile_id, target_id] =
linalg::UnravelIndex(i, n_samples, alpha.size(), n_targets / alpha.size());
assert(target_id == 0);

auto d = predt(i) - labels(sample_id, target_id);
auto h = weight[sample_id];
if (d >= 0) {
auto g = (1.0f - alpha[quantile_id]) * weight[sample_id];
gpair(sample_id, quantile_id) = GradientPair{g, h};
} else {
auto g = (-alpha[quantile_id] * weight[sample_id]);
gpair(sample_id, quantile_id) = GradientPair{g, h};
}
});
linalg::ElementWiseKernel(ctx_, gpair,
[=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
// j is the quantile index
// 0 is the target index
auto d = predt(i, j) - labels(i, 0);
auto h = weight[i];
if (d >= 0) {
auto g = (1.0f - alpha[j]) * weight[i];
gpair(i, j) = GradientPair{g, h};
} else {
auto g = (-alpha[j] * weight[i]);
gpair(i, j) = GradientPair{g, h};
}
});
}

void InitEstimation(MetaInfo const& info, linalg::Vector<float>* base_score) const override {
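The simplified kernel is the usual pinball-loss gradient, now indexed directly by sample i and quantile j instead of unraveling a flat index. With quantile level α_j, prediction p_{ij}, label y_i, weight w_i, and residual d = p_{ij} − y_i, the code above computes:

```latex
g_{ij} =
\begin{cases}
  (1 - \alpha_j)\,w_i & \text{if } d \ge 0,\\
  -\alpha_j\,w_i      & \text{if } d < 0,
\end{cases}
\qquad
h_{ij} = w_i .
```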
43 changes: 22 additions & 21 deletions src/objective/regression_obj.cu
@@ -255,24 +255,24 @@ class PseudoHuberRegression : public FitIntercept {
auto gpair = out_gpair->View(ctx_->Device());

preds.SetDevice(ctx_->Device());
auto predt = linalg::MakeVec(&preds);
auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, this->Targets(info));

info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
: info.weights_.ConstHostSpan()};

linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable {
auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape()));
const float z = predt(i) - y;
const float scale_sqrt = std::sqrt(1 + common::Sqr(z) / common::Sqr(slope));
float grad = z / scale_sqrt;
linalg::ElementWiseKernel(
ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
float z = predt(i, j) - labels(i, j);
float scale_sqrt = std::sqrt(1 + common::Sqr(z) / common::Sqr(slope));
float grad = z / scale_sqrt;

auto scale = common::Sqr(slope) + common::Sqr(z);
float hess = common::Sqr(slope) / (scale * scale_sqrt);
auto scale = common::Sqr(slope) + common::Sqr(z);
float hess = common::Sqr(slope) / (scale * scale_sqrt);

auto w = weight[sample_id];
gpair(i) = {grad * w, hess * w};
});
auto w = weight[i];
gpair(i) = {grad * w, hess * w};
});
}

[[nodiscard]] const char* DefaultEvalMetric() const override { return "mphe"; }
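The rewritten kernel is the standard Pseudo-Huber gradient and hessian with `slope` playing the role of δ; writing z = p − y, the quantities computed above are

```latex
L(z) = \delta^2\!\left(\sqrt{1 + (z/\delta)^2} - 1\right), \qquad
\frac{\partial L}{\partial p} = \frac{z}{\sqrt{1 + (z/\delta)^2}}, \qquad
\frac{\partial^2 L}{\partial p^2} = \frac{\delta^2}{(\delta^2 + z^2)\sqrt{1 + (z/\delta)^2}},
```

each scaled by the sample weight w before being stored in `gpair`.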
@@ -635,20 +635,21 @@ class MeanAbsoluteError : public ObjFunction {
auto gpair = out_gpair->View(ctx_->Device());

preds.SetDevice(ctx_->Device());
auto predt = linalg::MakeVec(&preds);
auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, this->Targets(info));
info.weights_.SetDevice(ctx_->Device());
common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
: info.weights_.ConstHostSpan()};

linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, float y) mutable {
auto sign = [](auto x) {
return (x > static_cast<decltype(x)>(0)) - (x < static_cast<decltype(x)>(0));
};
auto [sample_id, target_id] = linalg::UnravelIndex(i, labels.Shape());
auto grad = sign(predt(i) - y) * weight[sample_id];
auto hess = weight[sample_id];
gpair(sample_id, target_id) = GradientPair{grad, hess};
});
linalg::ElementWiseKernel(
ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
auto sign = [](auto x) {
return (x > static_cast<decltype(x)>(0)) - (x < static_cast<decltype(x)>(0));
};
auto y = labels(i, j);
auto hess = weight[i];
auto grad = sign(predt(i, j) - y) * hess;
gpair(i, j) = GradientPair{grad, hess};
});
}

void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_margin) const override {
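Because |p − y| has zero curvature almost everywhere, the kernel above uses the sample weight itself as the hessian, a common surrogate for absolute error; per element it computes

```latex
g_{ij} = \operatorname{sign}\!\bigl(p_{ij} - y_{ij}\bigr)\, w_i, \qquad h_{ij} = w_i .
```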