diff --git a/src/common/linalg_op.cuh b/src/common/linalg_op.cuh
index 361f926e9892..9ef36598d9de 100644
--- a/src/common/linalg_op.cuh
+++ b/src/common/linalg_op.cuh
@@ -1,31 +1,48 @@
-/*!
- * Copyright 2021-2022 by XGBoost Contributors
+/**
+ * Copyright 2021-2023, XGBoost Contributors
  */
 #ifndef XGBOOST_COMMON_LINALG_OP_CUH_
 #define XGBOOST_COMMON_LINALG_OP_CUH_
 
-#include "device_helpers.cuh"
+#include <cstdint>  // for int32_t
+#include <cstddef>  // for size_t
+#include <tuple>    // for apply
+
+#include "device_helpers.cuh"  // for LaunchN
 #include "linalg_op.h"
-#include "xgboost/context.h"
-#include "xgboost/linalg.h"
+#include "xgboost/context.h"  // for Context
+#include "xgboost/linalg.h"   // for TensorView
 
 namespace xgboost {
 namespace linalg {
-template <typename T, int32_t D, typename Fn>
-void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
-  dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
-  static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
-                "For function with return, use transform instead.");
-  if (t.Contiguous()) {
-    auto ptr = t.Values().data();
-    dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable { fn(i, ptr[i]); });
-  } else {
-    dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable {
-      T& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
-      fn(i, v);
+namespace cuda_impl {
+// Use template specialization to dispatch, as Windows + CUDA 11.8 doesn't support extended
+// lambda inside constexpr if.
+template <typename T, std::int32_t D>
+struct ElementWiseImpl {
+  template <typename Fn>
+  void operator()(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s) {
+    static_assert(D > 1);
+    dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) mutable {
+      std::apply(fn, linalg::UnravelIndex(i, t.Shape()));
     });
   }
+};
+
+template <typename T>
+struct ElementWiseImpl<T, 1> {
+  template <typename Fn>
+  void operator()(linalg::TensorView<T, 1> t, Fn&& fn, cudaStream_t s) {
+    dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) { fn(i); });
+  }
+};
+
+template <typename T, std::int32_t D, typename Fn>
+void ElementWiseKernel(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
+  dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
+  cuda_impl::ElementWiseImpl<T, D>{}(t, fn, s);
 }
+}  // namespace cuda_impl
 
 template <typename T, int32_t D, typename Fn>
 void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
@@ -42,7 +59,8 @@ void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_
 
 template <typename T, int32_t D, typename Fn>
 void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
-  ctx->IsCUDA() ? ElementWiseKernelDevice(t, fn) : ElementWiseKernelHost(t, ctx->Threads(), fn);
+  ctx->IsCUDA() ? cuda_impl::ElementWiseKernel(t, fn)
+                : ElementWiseKernelHost(t, ctx->Threads(), fn);
 }
 }  // namespace linalg
 }  // namespace xgboost
diff --git a/src/common/linalg_op.h b/src/common/linalg_op.h
index d89e5a736b6e..52141164f785 100644
--- a/src/common/linalg_op.h
+++ b/src/common/linalg_op.h
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2021-2022 by XGBoost Contributors
+/**
+ * Copyright 2021-2023, XGBoost Contributors
  */
 #ifndef XGBOOST_COMMON_LINALG_OP_H_
 #define XGBOOST_COMMON_LINALG_OP_H_
@@ -27,17 +27,23 @@ void ElementWiseTransformHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&
   }
 }
 
-template <typename T, int32_t D, typename Fn>
-void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& fn) {
-  static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
-                "For function with return, use transform instead.");
-  if (t.Contiguous()) {
-    auto ptr = t.Values().data();
-    common::ParallelFor(t.Size(), n_threads, [&](size_t i) { fn(i, ptr[i]); });
+template <typename T, std::int32_t D, typename Fn>
+void ElementWiseKernelHost(linalg::TensorView<T, D> t, std::int32_t n_threads, Fn &&fn) {
+  if constexpr (D == 1) {
+    common::ParallelFor(t.Size(), n_threads, [&](std::size_t i) { fn(i); });
+  } else if (D == 2 && t.CContiguous() && t.Shape(0) > t.Shape(1) * 64) {
+    // Heuristic. Tall, c-contiguous matrix.
+    auto n_rows = t.Shape(0);
+    auto n_columns = t.Shape(1);
+    common::ParallelFor(n_rows, n_threads, [&](std::size_t i) {
+      for (std::size_t j = 0; j < n_columns; ++j) {
+        fn(i, j);
+      }
+    });
   } else {
-    common::ParallelFor(t.Size(), n_threads, [&](size_t i) {
-      auto& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
-      fn(i, v);
+    common::ParallelFor(t.Size(), n_threads, [&](std::size_t i) {
+      auto idx = linalg::UnravelIndex(i, t.Shape());
+      std::apply(fn, idx);
     });
   }
 }
diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu
index dd9a19b13a07..fb51f325bc94 100644
--- a/src/objective/hinge.cu
+++ b/src/objective/hinge.cu
@@ -4,71 +4,85 @@
  * \brief Provides an implementation of the hinge loss function
  * \author Henry Gouk
  */
-#include "xgboost/objective.h"
-#include "xgboost/json.h"
-#include "xgboost/span.h"
-#include "xgboost/host_device_vector.h"
+#include <algorithm>  // for max
+#include <cstddef>    // for size_t
+#include <cstdint>    // for int32_t
 
-#include "../common/math.h"
-#include "../common/transform.h"
-#include "../common/common.h"
+#include "../common/common.h"  // for Range
+#if defined(XGBOOST_USE_CUDA)
+#include "../common/linalg_op.cuh"
+#endif
+#include "../common/linalg_op.h"
+#include "../common/optional_weight.h"   // for OptionalWeights
+#include "../common/transform.h"         // for Transform
+#include "init_estimation.h"             // for FitIntercept
+#include "xgboost/data.h"                // for MetaInfo
+#include "xgboost/host_device_vector.h"  // for HostDeviceVector
+#include "xgboost/json.h"                // for Json
+#include "xgboost/linalg.h"              // for UnravelIndex
+#include "xgboost/span.h"                // for Span
 
 namespace xgboost::obj {
-
 #if defined(XGBOOST_USE_CUDA)
 DMLC_REGISTRY_FILE_TAG(hinge_obj_gpu);
 #endif  // defined(XGBOOST_USE_CUDA)
 
-class HingeObj : public ObjFunction {
+class HingeObj : public FitIntercept {
  public:
   HingeObj() = default;
-  void Configure(Args const&) override {}
+  void Configure(Args const &) override {}
   ObjInfo Task() const override { return ObjInfo::kRegression; }
 
-  void GetGradient(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
+  [[nodiscard]] bst_target_t Targets(MetaInfo const &info) const override {
+    // Multi-target regression.
+    return std::max(static_cast<std::size_t>(1), info.labels.Shape(1));
+  }
+
+  void GetGradient(HostDeviceVector<float> const &preds, MetaInfo const &info,
                    std::int32_t /*iter*/, linalg::Matrix<GradientPair> *out_gpair) override {
-    CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty";
-    CHECK_EQ(preds.Size(), info.labels.Size())
-        << "labels are not correctly provided"
-        << "preds.size=" << preds.Size()
-        << ", label.size=" << info.labels.Size();
-
-    const size_t ndata = preds.Size();
-    const bool is_null_weight = info.weights_.Size() == 0;
-    if (!is_null_weight) {
-      CHECK_EQ(info.weights_.Size(), ndata)
+    CheckInitInputs(info);
+    CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
+    if (!info.weights_.Empty()) {
+      CHECK_EQ(info.weights_.Size(), info.num_row_)
           << "Number of weights should be equal to number of data points.";
     }
-    CHECK_EQ(info.labels.Shape(1), 1) << "Multi-target for `binary:hinge` is not yet supported.";
-    out_gpair->Reshape(ndata, 1);
-    common::Transform<>::Init(
-        [=] XGBOOST_DEVICE(size_t _idx,
-                           common::Span<GradientPair> _out_gpair,
-                           common::Span<const bst_float> _preds,
-                           common::Span<const bst_float> _labels,
-                           common::Span<const bst_float> _weights) {
-          bst_float p = _preds[_idx];
-          bst_float w = is_null_weight ? 1.0f : _weights[_idx];
-          bst_float y = _labels[_idx] * 2.0 - 1.0;
-          bst_float g, h;
-          if (p * y < 1.0) {
-            g = -y * w;
-            h = w;
-          } else {
-            g = 0.0;
-            h = std::numeric_limits<bst_float>::min();
-          }
-          _out_gpair[_idx] = GradientPair(g, h);
-        },
-        common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(),
-        ctx_->Device()).Eval(
-            out_gpair->Data(), &preds, info.labels.Data(), &info.weights_);
+
+    bst_target_t n_targets = this->Targets(info);
+    out_gpair->Reshape(info.num_row_, n_targets);
+    auto gpair = out_gpair->View(ctx_->Device());
+
+    preds.SetDevice(ctx_->Device());
+    auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, n_targets);
+
+    auto labels = info.labels.View(ctx_->Device());
+
+    info.weights_.SetDevice(ctx_->Device());
+    common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
+                                                  : info.weights_.ConstHostSpan()};
+
+    linalg::ElementWiseKernel(this->ctx_, labels,
+                              [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
+                                auto w = weight[i];
+
+                                auto p = predt(i, j);
+                                auto y = labels(i, j) * 2.0 - 1.0;
+
+                                float g, h;
+                                if (p * y < 1.0) {
+                                  g = -y * w;
+                                  h = w;
+                                } else {
+                                  g = 0.0;
+                                  h = std::numeric_limits<float>::min();
+                                }
+                                gpair(i, j) = GradientPair{g, h};
+                              });
   }
 
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
+  void PredTransform(HostDeviceVector<float> *io_preds) const override {
     common::Transform<>::Init(
-        [] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
+        [] XGBOOST_DEVICE(std::size_t _idx, common::Span<float> _preds) {
           _preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
         },
         common::Range{0, static_cast<int64_t>(io_preds->Size()), 1}, this->ctx_->Threads(),
@@ -76,12 +90,10 @@ class HingeObj : public ObjFunction {
         .Eval(io_preds);
   }
 
-  [[nodiscard]] const char* DefaultEvalMetric() const override {
-    return "error";
-  }
+  [[nodiscard]] const char *DefaultEvalMetric() const override { return "error"; }
 
-  void SaveConfig(Json* p_out) const override {
-    auto& out = *p_out;
+  void SaveConfig(Json *p_out) const override {
+    auto &out = *p_out;
     out["name"] = String("binary:hinge");
   }
   void LoadConfig(Json const &) override {}
@@ -89,7 +101,7 @@
 
 // register the objective functions
 XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge")
-.describe("Hinge loss. Expects labels to be in [0,1f]")
Expects labels to be in [0,1f]") -.set_body([]() { return new HingeObj(); }); + .describe("Hinge loss. Expects labels to be in [0,1f]") + .set_body([]() { return new HingeObj(); }); } // namespace xgboost::obj diff --git a/src/objective/quantile_obj.cu b/src/objective/quantile_obj.cu index 57f432c7f161..15ec72f95d91 100644 --- a/src/objective/quantile_obj.cu +++ b/src/objective/quantile_obj.cu @@ -75,28 +75,25 @@ class QuantileRegression : public ObjFunction { : info.weights_.ConstHostSpan()}; preds.SetDevice(ctx_->Device()); - auto predt = linalg::MakeVec(&preds); - auto n_samples = info.num_row_; + auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, n_targets); alpha_.SetDevice(ctx_->Device()); auto alpha = ctx_->IsCUDA() ? alpha_.ConstDeviceSpan() : alpha_.ConstHostSpan(); - linalg::ElementWiseKernel( - ctx_, gpair, [=] XGBOOST_DEVICE(std::size_t i, GradientPair const&) mutable { - auto [sample_id, quantile_id, target_id] = - linalg::UnravelIndex(i, n_samples, alpha.size(), n_targets / alpha.size()); - assert(target_id == 0); - - auto d = predt(i) - labels(sample_id, target_id); - auto h = weight[sample_id]; - if (d >= 0) { - auto g = (1.0f - alpha[quantile_id]) * weight[sample_id]; - gpair(sample_id, quantile_id) = GradientPair{g, h}; - } else { - auto g = (-alpha[quantile_id] * weight[sample_id]); - gpair(sample_id, quantile_id) = GradientPair{g, h}; - } - }); + linalg::ElementWiseKernel(ctx_, gpair, + [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable { + // j is the quantile index + // 0 is the target index + auto d = predt(i, j) - labels(i, 0); + auto h = weight[i]; + if (d >= 0) { + auto g = (1.0f - alpha[j]) * weight[i]; + gpair(i, j) = GradientPair{g, h}; + } else { + auto g = (-alpha[j] * weight[i]); + gpair(i, j) = GradientPair{g, h}; + } + }); } void InitEstimation(MetaInfo const& info, linalg::Vector* base_score) const override { diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index f74d01acc74d..5627600fc187 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -255,24 +255,24 @@ class PseudoHuberRegression : public FitIntercept { auto gpair = out_gpair->View(ctx_->Device()); preds.SetDevice(ctx_->Device()); - auto predt = linalg::MakeVec(&preds); + auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, this->Targets(info)); info.weights_.SetDevice(ctx_->Device()); common::OptionalWeights weight{ctx_->IsCUDA() ? 
                                                   : info.weights_.ConstHostSpan()};
 
-    linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable {
-      auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape()));
-      const float z = predt(i) - y;
-      const float scale_sqrt = std::sqrt(1 + common::Sqr(z) / common::Sqr(slope));
-      float grad = z / scale_sqrt;
+    linalg::ElementWiseKernel(
+        ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
+          float z = predt(i, j) - labels(i, j);
+          float scale_sqrt = std::sqrt(1 + common::Sqr(z) / common::Sqr(slope));
+          float grad = z / scale_sqrt;
 
-      auto scale = common::Sqr(slope) + common::Sqr(z);
-      float hess = common::Sqr(slope) / (scale * scale_sqrt);
+          auto scale = common::Sqr(slope) + common::Sqr(z);
+          float hess = common::Sqr(slope) / (scale * scale_sqrt);
 
-      auto w = weight[sample_id];
-      gpair(i) = {grad * w, hess * w};
-    });
+          auto w = weight[i];
+          gpair(i) = {grad * w, hess * w};
+        });
   }
 
   [[nodiscard]] const char* DefaultEvalMetric() const override { return "mphe"; }
@@ -635,20 +635,21 @@ class MeanAbsoluteError : public ObjFunction {
     auto gpair = out_gpair->View(ctx_->Device());
 
     preds.SetDevice(ctx_->Device());
-    auto predt = linalg::MakeVec(&preds);
+    auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, this->Targets(info));
     info.weights_.SetDevice(ctx_->Device());
     common::OptionalWeights weight{ctx_->IsCUDA() ? info.weights_.ConstDeviceSpan()
                                                   : info.weights_.ConstHostSpan()};
 
-    linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, float y) mutable {
-      auto sign = [](auto x) {
-        return (x > static_cast<decltype(x)>(0)) - (x < static_cast<decltype(x)>(0));
-      };
-      auto [sample_id, target_id] = linalg::UnravelIndex(i, labels.Shape());
-      auto grad = sign(predt(i) - y) * weight[sample_id];
-      auto hess = weight[sample_id];
-      gpair(sample_id, target_id) = GradientPair{grad, hess};
-    });
+    linalg::ElementWiseKernel(
+        ctx_, labels, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
+          auto sign = [](auto x) {
+            return (x > static_cast<decltype(x)>(0)) - (x < static_cast<decltype(x)>(0));
+          };
+          auto y = labels(i, j);
+          auto hess = weight[i];
+          auto grad = sign(predt(i, j) - y) * hess;
+          gpair(i, j) = GradientPair{grad, hess};
+        });
   }
 
   void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_margin) const override {
diff --git a/tests/cpp/common/test_linalg.cu b/tests/cpp/common/test_linalg.cu
index 4823b1191088..5f8bab4a3cc4 100644
--- a/tests/cpp/common/test_linalg.cu
+++ b/tests/cpp/common/test_linalg.cu
@@ -23,7 +23,7 @@ void TestElementWiseKernel() {
    ElementWiseTransformDevice(t, [] __device__(size_t i, float) { return i; });
     // CPU view
     t = l.View(DeviceOrd::CPU()).Slice(linalg::All(), 1, linalg::All());
-    size_t k = 0;
+    std::size_t k = 0;
     for (size_t i = 0; i < l.Shape(0); ++i) {
       for (size_t j = 0; j < l.Shape(2); ++j) {
         ASSERT_EQ(k++, t(i, j));
@@ -31,7 +31,15 @@ void TestElementWiseKernel() {
     }
 
     t = l.View(device).Slice(linalg::All(), 1, linalg::All());
-    ElementWiseKernelDevice(t, [] XGBOOST_DEVICE(size_t i, float v) { SPAN_CHECK(v == i); });
+    cuda_impl::ElementWiseKernel(
+        t, [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable { t(i, j) = i + j; });
+
+    t = l.Slice(linalg::All(), 1, linalg::All());
+    for (size_t i = 0; i < l.Shape(0); ++i) {
+      for (size_t j = 0; j < l.Shape(2); ++j) {
+        ASSERT_EQ(i + j, t(i, j));
+      }
+    }
   }
 
   {
diff --git a/tests/cpp/data/test_metainfo.h b/tests/cpp/data/test_metainfo.h
index fba882e0e66e..92cd6cb91e43 100644
--- a/tests/cpp/data/test_metainfo.h
+++ b/tests/cpp/data/test_metainfo.h
@@ -31,12 +31,10 @@ inline void TestMetaInfoStridedData(DeviceOrd device) {
     auto const& h_result = info.labels.View(DeviceOrd::CPU());
     ASSERT_EQ(h_result.Shape().size(), 2);
     auto in_labels = labels.View(DeviceOrd::CPU());
-    linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float& v_0) {
-      auto tup = linalg::UnravelIndex(i, h_result.Shape());
-      auto i0 = std::get<0>(tup);
-      auto i1 = std::get<1>(tup);
+    linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, std::size_t j) {
       // Sliced at second dimension.
-      auto v_1 = in_labels(i0, 0, i1);
+      auto v_0 = h_result(i, j);
+      auto v_1 = in_labels(i, 0, j);
       CHECK_EQ(v_0, v_1);
     });
   }
@@ -65,14 +63,13 @@ inline void TestMetaInfoStridedData(DeviceOrd device) {
     auto const& h_result = info.base_margin_.View(DeviceOrd::CPU());
     ASSERT_EQ(h_result.Shape().size(), 2);
     auto in_margin = base_margin.View(DeviceOrd::CPU());
-    linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(), [&](size_t i, float v_0) {
-      auto tup = linalg::UnravelIndex(i, h_result.Shape());
-      auto i0 = std::get<0>(tup);
-      auto i1 = std::get<1>(tup);
-      // Sliced at second dimension.
-      auto v_1 = in_margin(i0, 0, i1);
-      CHECK_EQ(v_0, v_1);
-    });
+    linalg::ElementWiseKernelHost(h_result, omp_get_max_threads(),
+                                  [&](std::size_t i, std::size_t j) {
+                                    // Sliced at second dimension.
+                                    auto v_0 = h_result(i, j);
+                                    auto v_1 = in_margin(i, 0, j);
+                                    CHECK_EQ(v_0, v_1);
+                                  });
   }
 }
 }  // namespace xgboost
diff --git a/tests/cpp/objective/test_hinge.cc b/tests/cpp/objective/test_hinge.cc
index 17d2609d4ff1..70e8b5626a22 100644
--- a/tests/cpp/objective/test_hinge.cc
+++ b/tests/cpp/objective/test_hinge.cc
@@ -1,28 +1,55 @@
-// Copyright by Contributors
+/**
+ * Copyright 2018-2023, XGBoost Contributors
+ */
 #include <gtest/gtest.h>
 #include <xgboost/objective.h>
 #include <limits>
 
 #include "../helpers.h"
+#include "../../../src/common/linalg_op.h"
 
 namespace xgboost {
 
 TEST(Objective, DeclareUnifiedTest(HingeObj)) {
   Context ctx = MakeCUDACtx(GPUIDX);
   std::unique_ptr<ObjFunction> obj{ObjFunction::Create("binary:hinge", &ctx)};
   float eps = std::numeric_limits<float>::min();
 
-  CheckObjFunction(obj,
-                   {-1.0f, -0.5f, 0.5f, 1.0f, -1.0f, -0.5f, 0.5f, 1.0f},
-                   { 0.0f,  0.0f, 0.0f, 0.0f,  1.0f,  1.0f, 1.0f, 1.0f},
-                   { 1.0f,  1.0f, 1.0f, 1.0f,  1.0f,  1.0f, 1.0f, 1.0f},
-                   { 0.0f,  1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f},
-                   { eps,   1.0f, 1.0f, 1.0f,  1.0f,  1.0f,  1.0f, eps });
-  CheckObjFunction(obj,
-                   {-1.0f, -0.5f, 0.5f, 1.0f, -1.0f, -0.5f, 0.5f, 1.0f},
-                   { 0.0f,  0.0f, 0.0f, 0.0f,  1.0f,  1.0f, 1.0f, 1.0f},
-                   {},  // Empty weight.
-                   { 0.0f,  1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f},
-                   { eps,   1.0f, 1.0f, 1.0f,  1.0f,  1.0f,  1.0f, eps });
+  std::vector<float> predt{-1.0f, -0.5f, 0.5f, 1.0f, -1.0f, -0.5f, 0.5f, 1.0f};
+  std::vector<float> label{0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f};
+  std::vector<float> grad{0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f};
+  std::vector<float> hess{eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps};
 
-  ASSERT_NO_THROW(obj->DefaultEvalMetric());
+  CheckObjFunction(obj, predt, label, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, grad, hess);
+  CheckObjFunction(obj, predt, label, {/* Empty weight. */}, grad, hess);
+
+  ASSERT_EQ(obj->DefaultEvalMetric(), StringView{"error"});
+
+  MetaInfo info;
+  info.num_row_ = label.size();
+  info.labels.Reshape(info.num_row_, 3);
+  ASSERT_EQ(obj->Targets(info), 3);
+  auto h_labels = info.labels.HostView();
+  for (std::size_t j = 0; j < obj->Targets(info); ++j) {
+    for (std::size_t i = 0; i < info.num_row_; ++i) {
+      h_labels(i, j) = label[i];
+    }
+  }
+  linalg::Tensor<float, 2> t_predt{};
+  t_predt.Reshape(info.labels.Shape());
+  for (std::size_t j = 0; j < obj->Targets(info); ++j) {
+    for (std::size_t i = 0; i < info.num_row_; ++i) {
+      t_predt(i, j) = predt[i];
+    }
+  }
+  linalg::Matrix<GradientPair> out_gpair;
+  obj->GetGradient(*t_predt.Data(), info, 0, &out_gpair);
+
+  for (std::size_t j = 0; j < obj->Targets(info); ++j) {
+    auto gh = out_gpair.Slice(linalg::All(), j);
+    ASSERT_EQ(gh.Size(), info.num_row_);
+    for (std::size_t i = 0; i < gh.Size(); ++i) {
+      ASSERT_EQ(gh(i).GetGrad(), grad[i]);
+      ASSERT_EQ(gh(i).GetHess(), hess[i]);
+    }
+  }
 }
 }  // namespace xgboost
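
Usage sketch, not part of the patch: after this change, element-wise kernels receive one std::size_t index per tensor dimension (unraveled via linalg::UnravelIndex and dispatched through std::apply) instead of a flat offset plus an element reference. The snippet below is a minimal, hypothetical illustration of the new host-side contract; the function name Example, the shapes, and the include path are assumptions for the sketch, not code from this patch.

    // Hypothetical sketch of the reworked callback contract.
    #include <cstddef>  // for std::size_t

    #include "common/linalg_op.h"  // for ElementWiseKernelHost; path is illustrative
    #include "xgboost/linalg.h"    // for Tensor

    void Example() {
      xgboost::linalg::Tensor<float, 2> t;
      t.Reshape(8, 2);  // 8 rows, 2 targets
      auto v = t.HostView();
      // The callback now takes (row i, column j) rather than (flat index, element&);
      // elements are read and written through the captured view.
      xgboost::linalg::ElementWiseKernelHost(v, /*n_threads=*/2,
                                             [&](std::size_t i, std::size_t j) {
                                               v(i, j) = static_cast<float>(i * 2 + j);
                                             });
    }

With a 1-dim view the same entry point invokes the callback as fn(i), and on CUDA builds the cuda_impl::ElementWiseKernel overload added in linalg_op.cuh applies the same per-dimension convention on device.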