Optimize adapter element counting on GPU. #9209

Merged: 2 commits (May 30, 2023)
49 changes: 44 additions & 5 deletions include/xgboost/span.h
@@ -1,5 +1,5 @@
/*!
* Copyright 2018 XGBoost contributors
/**
* Copyright 2018-2023, XGBoost contributors
* \brief span class based on ISO++20 span
*
* About NOLINTs in this file:
@@ -32,11 +32,12 @@
#include <xgboost/base.h>
#include <xgboost/logging.h>

#include <cinttypes> // size_t
#include <limits> // numeric_limits
#include <cinttypes> // size_t
#include <cstdio>
#include <iterator>
#include <limits> // numeric_limits
#include <type_traits>
#include <cstdio>
#include <utility> // for move

#if defined(__CUDACC__)
#include <cuda_runtime.h>
@@ -668,6 +669,44 @@ XGBOOST_DEVICE auto as_writable_bytes(Span<T, E> s) __span_noexcept -> // NOLINT
Span<byte, detail::ExtentAsBytesValue<T, E>::value> {
return {reinterpret_cast<byte*>(s.data()), s.size_bytes()};
}

/**
* \brief A simple custom Span type that uses general iterator instead of pointer.
*/
template <typename It>
class IterSpan {
public:
using value_type = typename std::iterator_traits<It>::value_type; // NOLINT
using index_type = std::size_t; // NOLINT
using iterator = It; // NOLINT

private:
It it_;
index_type size_{0};

public:
IterSpan() = default;
XGBOOST_DEVICE IterSpan(It it, index_type size) : it_{std::move(it)}, size_{size} {}
XGBOOST_DEVICE explicit IterSpan(common::Span<It, dynamic_extent> span)
: it_{span.data()}, size_{span.size()} {}

[[nodiscard]] XGBOOST_DEVICE index_type size() const noexcept { return size_; } // NOLINT
[[nodiscard]] XGBOOST_DEVICE decltype(auto) operator[](index_type i) const { return it_[i]; }
[[nodiscard]] XGBOOST_DEVICE decltype(auto) operator[](index_type i) { return it_[i]; }
[[nodiscard]] XGBOOST_DEVICE bool empty() const noexcept { return size() == 0; } // NOLINT
[[nodiscard]] XGBOOST_DEVICE It data() const noexcept { return it_; } // NOLINT
[[nodiscard]] XGBOOST_DEVICE IterSpan<It> subspan( // NOLINT
index_type _offset, index_type _count = dynamic_extent) const {
SPAN_CHECK((_count == dynamic_extent) ? (_offset <= size()) : (_offset + _count <= size()));
return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count};
}
[[nodiscard]] XGBOOST_DEVICE constexpr iterator begin() const noexcept { // NOLINT
return {this, 0};
}
[[nodiscard]] XGBOOST_DEVICE constexpr iterator end() const noexcept { // NOLINT
return {this, size()};
}
};
} // namespace common
} // namespace xgboost

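For readers unfamiliar with the new type: IterSpan mirrors the pointer-based common::Span but stores an arbitrary random-access iterator plus an explicit length, which is what later lets the sketching code hand a transform iterator over an adapter batch straight to GetColumnSizesScan. The host-side sketch below is illustration only and not part of the diff; the vector contents and printed values are assumptions.

// Illustration only: IterSpan over a std::vector iterator. operator[] forwards
// to it_[i], and subspan() bounds-checks the offset before shifting the iterator.
#include <cstddef>
#include <cstdio>
#include <vector>

#include "xgboost/span.h"  // IterSpan lives next to common::Span

int main() {
  std::vector<float> values{1.f, 2.f, 3.f, 4.f, 5.f};
  xgboost::common::IterSpan<std::vector<float>::iterator> span{values.begin(), values.size()};
  float sum = 0.f;
  for (std::size_t i = 0; i < span.size(); ++i) {
    sum += span[i];
  }
  auto tail = span.subspan(2);  // elements {3, 4, 5}
  std::printf("sum=%.1f tail_size=%zu\n", sum, tail.size());
  return 0;
}
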
10 changes: 4 additions & 6 deletions src/common/hist_util.cu
@@ -204,9 +204,8 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size.
});
detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
batch_it, dummy_is_valid,
0, sorted_entries.size(),
&cuts_ptr, &column_sizes_scan);
IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
&column_sizes_scan);
auto d_cuts_ptr = cuts_ptr.DeviceSpan();

if (sketch_container->HasCategorical()) {
@@ -273,9 +272,8 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size.
});
detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
batch_it, dummy_is_valid,
0, sorted_entries.size(),
&cuts_ptr, &column_sizes_scan);
IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
&column_sizes_scan);
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
if (sketch_container->HasCategorical()) {
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
143 changes: 117 additions & 26 deletions src/common/hist_util.cuh
@@ -53,32 +53,126 @@ struct EntryCompareOp {
};

// Get column size from adapter batch and for output cuts.
template <typename Iter>
void GetColumnSizesScan(int device, size_t num_columns, size_t num_cuts_per_feature,
Iter batch_iter, data::IsValidFunctor is_valid,
size_t begin, size_t end,
HostDeviceVector<SketchContainer::OffsetT> *cuts_ptr,
template <std::uint32_t kBlockThreads, typename CounterT, typename BatchIt>
__global__ void GetColumnSizeSharedMemKernel(IterSpan<BatchIt> batch_iter,
data::IsValidFunctor is_valid,
Span<std::size_t> out_column_size) {
extern __shared__ char smem[];

auto smem_cs_ptr = reinterpret_cast<CounterT*>(smem);

dh::BlockFill(smem_cs_ptr, out_column_size.size(), 0);

cub::CTA_SYNC();

auto n = batch_iter.size();

for (auto idx : dh::GridStrideRange(static_cast<std::size_t>(0), n)) {
auto e = batch_iter[idx];
if (is_valid(e)) {
atomicAdd(&smem_cs_ptr[e.column_idx], static_cast<CounterT>(1));
}
}

cub::CTA_SYNC();

auto out_global_ptr = out_column_size;
for (auto i : dh::BlockStrideRange(static_cast<std::size_t>(0), out_column_size.size())) {
atomicAdd(&out_global_ptr[i], static_cast<std::size_t>(smem_cs_ptr[i]));
}
}

template <std::uint32_t kBlockThreads, typename Kernel>
std::uint32_t EstimateGridSize(std::int32_t device, Kernel kernel, std::size_t shared_mem) {
int n_mps = 0;
dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
int n_blocks_per_mp = 0;
dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel,
kBlockThreads, shared_mem));
std::uint32_t grid_size = n_blocks_per_mp * n_mps;
return grid_size;
}

/**
* \brief Get the size of each column. This is a histogram with additional handling of
* invalid values.
*
* \tparam BatchIt Type of input adapter batch.
* \tparam force_use_global_memory Used for testing. Force global atomic add.
* \tparam force_use_u64 Used for testing. Force u64 as the counter type in shared memory.
*
* \param device CUDA device ordinal.
* \param batch_iter Iterator for input data from adapter batch.
* \param is_valid Whether an element is considered missing.
* \param out_column_size Output buffer for the size of each column.
*/
template <typename BatchIt, bool force_use_global_memory = false, bool force_use_u64 = false>
void LaunchGetColumnSizeKernel(std::int32_t device, IterSpan<BatchIt> batch_iter,
data::IsValidFunctor is_valid, Span<std::size_t> out_column_size) {
thrust::fill_n(thrust::device, dh::tbegin(out_column_size), out_column_size.size(), 0);

std::size_t max_shared_memory = dh::MaxSharedMemory(device);
// Not strictly correct as we should use number of samples to determine the type of
// counter. However, the sample size is not known due to sliding window on number of
// elements.
std::size_t n = batch_iter.size();

std::size_t required_shared_memory = 0;
bool use_u32{false};
if (!force_use_u64 && n < static_cast<std::size_t>(std::numeric_limits<std::uint32_t>::max())) {
required_shared_memory = out_column_size.size() * sizeof(std::uint32_t);
use_u32 = true;
} else {
required_shared_memory = out_column_size.size() * sizeof(std::size_t);
use_u32 = false;
}
bool use_shared = required_shared_memory <= max_shared_memory && required_shared_memory != 0;

if (!force_use_global_memory && use_shared) {
CHECK_NE(required_shared_memory, 0);
std::uint32_t constexpr kBlockThreads = 512;
if (use_u32) {
CHECK(!force_use_u64);
auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::uint32_t, BatchIt>;
auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, dh::DefaultStream()}(
kernel, batch_iter, is_valid, out_column_size);
} else {
auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::size_t, BatchIt>;
auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, dh::DefaultStream()}(
kernel, batch_iter, is_valid, out_column_size);
}
} else {
auto d_out_column_size = out_column_size;
dh::LaunchN(batch_iter.size(), [=] __device__(size_t idx) {
auto e = batch_iter[idx];
if (is_valid(e)) {
atomicAdd(&d_out_column_size[e.column_idx], static_cast<size_t>(1));
}
});
}
}

template <typename BatchIt>
void GetColumnSizesScan(int device, size_t num_columns, std::size_t num_cuts_per_feature,
IterSpan<BatchIt> batch_iter, data::IsValidFunctor is_valid,
HostDeviceVector<SketchContainer::OffsetT>* cuts_ptr,
dh::caching_device_vector<size_t>* column_sizes_scan) {
column_sizes_scan->resize(num_columns + 1, 0);
column_sizes_scan->resize(num_columns + 1);
cuts_ptr->SetDevice(device);
cuts_ptr->Resize(num_columns + 1, 0);

dh::XGBCachingDeviceAllocator<char> alloc;
auto d_column_sizes_scan = column_sizes_scan->data().get();
dh::LaunchN(end - begin, [=] __device__(size_t idx) {
auto e = batch_iter[begin + idx];
if (is_valid(e)) {
atomicAdd(&d_column_sizes_scan[e.column_idx], static_cast<size_t>(1));
}
});
auto d_column_sizes_scan = dh::ToSpan(*column_sizes_scan);
LaunchGetColumnSizeKernel(device, batch_iter, is_valid, d_column_sizes_scan);
// Calculate cuts CSC pointer
auto cut_ptr_it = dh::MakeTransformIterator<size_t>(
column_sizes_scan->begin(), [=] __device__(size_t column_size) {
return thrust::min(num_cuts_per_feature, column_size);
});
thrust::exclusive_scan(thrust::cuda::par(alloc), cut_ptr_it,
cut_ptr_it + column_sizes_scan->size(),
cuts_ptr->DevicePointer());
cut_ptr_it + column_sizes_scan->size(), cuts_ptr->DevicePointer());
thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
column_sizes_scan->end(), column_sizes_scan->begin());
}
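
To make the two outputs of the rewritten GetColumnSizesScan concrete: the per-column valid-element counts are capped at num_cuts_per_feature and exclusive-scanned into cuts_ptr, while the raw counts are exclusive-scanned in place into a CSC-style pointer whose last entry is the total number of valid elements (used later by MakeEntriesFromAdapter). The sketch below is a plain host-side illustration with made-up counts, not code from this PR.

// Illustration only: expected contents of the two outputs for 3 columns with
// valid-element counts {5, 0, 12} and num_cuts_per_feature = 8. The trailing
// zero mirrors resize(num_columns + 1).
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::size_t const num_cuts_per_feature = 8;
  std::vector<std::size_t> column_sizes{5, 0, 12, 0};

  // cuts_ptr: exclusive scan over min(num_cuts_per_feature, column_size).
  std::vector<std::size_t> capped(column_sizes.size());
  std::transform(column_sizes.begin(), column_sizes.end(), capped.begin(),
                 [&](std::size_t n) { return std::min(num_cuts_per_feature, n); });
  std::vector<std::size_t> cuts_ptr(column_sizes.size());
  std::exclusive_scan(capped.begin(), capped.end(), cuts_ptr.begin(), std::size_t{0});

  // column_sizes_scan: exclusive scan over the raw counts (CSC pointer).
  std::vector<std::size_t> column_sizes_scan(column_sizes.size());
  std::exclusive_scan(column_sizes.begin(), column_sizes.end(), column_sizes_scan.begin(),
                      std::size_t{0});

  // Prints: 0 5 5 13 | 0 5 5 17  (column_sizes_scan.back() == total valid elements)
  for (auto v : cuts_ptr) std::printf("%zu ", v);
  std::printf("| ");
  for (auto v : column_sizes_scan) std::printf("%zu ", v);
  std::printf("\n");
  return 0;
}
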
@@ -121,29 +215,26 @@ size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,

// Count the valid entries in each column and copy them out.
template <typename AdapterBatch, typename BatchIter>
void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
Range1d range, float missing,
size_t columns, size_t cuts_per_feature, int device,
void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, Range1d range,
float missing, size_t columns, size_t cuts_per_feature, int device,
HostDeviceVector<SketchContainer::OffsetT>* cut_sizes_scan,
dh::caching_device_vector<size_t>* column_sizes_scan,
dh::device_vector<Entry>* sorted_entries) {
auto entry_iter = dh::MakeTransformIterator<Entry>(
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
return Entry(batch.GetElement(idx).column_idx,
batch.GetElement(idx).value);
return Entry(batch.GetElement(idx).column_idx, batch.GetElement(idx).value);
});
auto n = range.end() - range.begin();
auto span = IterSpan{batch_iter + range.begin(), n};
data::IsValidFunctor is_valid(missing);
// Work out how many valid entries we have in each column
GetColumnSizesScan(device, columns, cuts_per_feature,
batch_iter, is_valid,
range.begin(), range.end(),
cut_sizes_scan,
GetColumnSizesScan(device, columns, cuts_per_feature, span, is_valid, cut_sizes_scan,
column_sizes_scan);
size_t num_valid = column_sizes_scan->back();
// Copy current subset of valid elements into temporary storage and sort
sorted_entries->resize(num_valid);
dh::CopyIf(entry_iter + range.begin(), entry_iter + range.end(),
sorted_entries->begin(), is_valid);
dh::CopyIf(entry_iter + range.begin(), entry_iter + range.end(), sorted_entries->begin(),
is_valid);
}

void SortByWeight(dh::device_vector<float>* weights,
55 changes: 45 additions & 10 deletions src/data/device_adapter.cuh
@@ -39,6 +39,14 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
return {row_idx, column_idx, value};
}

__device__ float GetElement(bst_row_t ridx, bst_feature_t fidx) const {
auto const& column = columns_[fidx];
float value = column.valid.Data() == nullptr || column.valid.Check(ridx)
? column(ridx)
: std::numeric_limits<float>::quiet_NaN();
return value;
}

XGBOOST_DEVICE bst_row_t NumRows() const { return num_rows_; }
XGBOOST_DEVICE bst_row_t NumCols() const { return columns_.size(); }

@@ -160,6 +168,10 @@ class CupyAdapterBatch : public detail::NoMetaInfo {
float value = array_interface_(row_idx, column_idx);
return {row_idx, column_idx, value};
}
__device__ float GetElement(bst_row_t ridx, bst_feature_t fidx) const {
float value = array_interface_(ridx, fidx);
return value;
}

XGBOOST_DEVICE bst_row_t NumRows() const { return array_interface_.Shape(0); }
XGBOOST_DEVICE bst_row_t NumCols() const { return array_interface_.Shape(1); }
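
The two-argument GetElement added to CudfAdapterBatch and CupyAdapterBatch is what lets GetRowCounts (next hunk) address an element directly by (row, feature) instead of decoding a flat element index. A rough host-side mock of that contract is sketched below; the struct and its names are assumptions for illustration, not XGBoost API.

// Illustration only: the minimal adapter-batch surface the reworked row counting
// relies on. A real batch returns NaN for absent entries (or lets IsValidFunctor
// filter the `missing` value), similar to CudfAdapterBatch::GetElement above.
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <vector>

struct MockBatch {
  std::vector<float> data;  // row-major dense storage; NaN marks an absent entry
  std::size_t n_rows{0};
  std::size_t n_cols{0};

  std::size_t NumRows() const { return n_rows; }
  std::size_t NumCols() const { return n_cols; }
  // Random access by (row, feature), mirroring the new two-argument overload.
  float GetElement(std::size_t ridx, std::size_t fidx) const {
    return data[ridx * n_cols + fidx];
  }
};

int main() {
  float nan = std::numeric_limits<float>::quiet_NaN();
  MockBatch batch{{1.f, nan, 3.f, nan, nan, 6.f}, 2, 3};
  // Count the valid entries of row 0 the way the kernel does, feature by feature.
  std::size_t cnt = 0;
  for (std::size_t f = 0; f < batch.NumCols(); ++f) {
    if (!std::isnan(batch.GetElement(0, f))) {
      ++cnt;
    }
  }
  std::printf("row 0 has %zu valid entries\n", cnt);  // prints 2
  return 0;
}
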
@@ -196,24 +208,47 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {

// Returns maximum row length
template <typename AdapterBatchT>
size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
int device_idx, float missing) {
std::size_t GetRowCounts(const AdapterBatchT batch, common::Span<bst_row_t> offset, int device_idx,
float missing) {
dh::safe_cuda(cudaSetDevice(device_idx));
IsValidFunctor is_valid(missing);
dh::safe_cuda(cudaMemsetAsync(offset.data(), '\0', offset.size_bytes()));

auto n_samples = batch.NumRows();
bst_feature_t n_features = batch.NumCols();

// Use more than 1 threads for each row in case of dataset being too wide.
bst_feature_t stride{0};
if (n_features < 32) {
stride = std::min(n_features, 4u);
} else if (n_features < 64) {
stride = 8;
} else if (n_features < 128) {
stride = 16;
} else {
stride = 32;
}

// Count elements per row
dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
auto element = batch.GetElement(idx);
if (is_valid(element)) {
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
&offset[element.row_idx]),
static_cast<unsigned long long>(1)); // NOLINT
dh::LaunchN(n_samples * stride, [=] __device__(std::size_t idx) {
bst_row_t cnt{0};
auto [ridx, fbeg] = linalg::UnravelIndex(idx, n_samples, stride);
SPAN_CHECK(ridx < n_samples);
for (bst_feature_t fidx = fbeg; fidx < n_features; fidx += stride) {
if (is_valid(batch.GetElement(ridx, fidx))) {
cnt++;
}
}

atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
&offset[ridx]),
static_cast<unsigned long long>(cnt)); // NOLINT
});
dh::XGBCachingDeviceAllocator<char> alloc;
size_t row_stride =
bst_row_t row_stride =
dh::Reduce(thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data()) + offset.size(),
static_cast<std::size_t>(0), thrust::maximum<size_t>());
static_cast<bst_row_t>(0), thrust::maximum<bst_row_t>());
return row_stride;
}

4 changes: 2 additions & 2 deletions src/data/iterative_dmatrix.cu
@@ -80,7 +80,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
}
auto batch_rows = num_rows();
accumulated_rows += batch_rows;
dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0);
dh::device_vector<size_t> row_counts(batch_rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const& value) {
return GetRowCounts(value, row_counts_span, get_device(), missing);
@@ -134,7 +134,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
init_page();
dh::safe_cuda(cudaSetDevice(get_device()));
auto rows = num_rows();
dh::caching_device_vector<size_t> row_counts(rows + 1, 0);
dh::device_vector<size_t> row_counts(rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
Dispatch(proxy, [=](auto const& value) {
return GetRowCounts(value, row_counts_span, get_device(), missing);