Commit 3d4e277

Improve update position function

RAMitchell committed Nov 14, 2018
1 parent 143475b commit 3d4e277
Showing 3 changed files with 188 additions and 57 deletions.
74 changes: 73 additions & 1 deletion src/common/device_helpers.cuh
@@ -766,7 +766,8 @@ typename std::iterator_traits<T>::value_type SumReduction(
dh::CubMemory &tmp_mem, T in, int nVals) {
using ValueT = typename std::iterator_traits<T>::value_type;
size_t tmpSize;
dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, in, nVals));
ValueT *dummy_out = nullptr;
dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, dummy_out, nVals));
// Allocate small extra memory for the return value
tmp_mem.LazyAllocate(tmpSize + sizeof(ValueT));
auto ptr = reinterpret_cast<ValueT *>(tmp_mem.d_temp_storage) + 1;
@@ -1074,4 +1075,75 @@ xgboost::common::Span<T> ToSpan(thrust::device_vector<T>& vec,
using IndexT = typename xgboost::common::Span<T>::index_type;
return ToSpan(vec, static_cast<IndexT>(offset), static_cast<IndexT>(size));
}

template <typename func_t>
class LauncherItr {
public:
int idx;
func_t f;
XGBOOST_DEVICE LauncherItr() : idx(0) {}
XGBOOST_DEVICE LauncherItr(int idx, func_t f) : idx(idx), f(f) {}
XGBOOST_DEVICE LauncherItr &operator=(int output) {
f(idx, output);
return *this;
}
};
/**
* \class DiscardLambdaItr
*
* \brief Thrust-compatible iterator type - discards the algorithm output and
* instead launches a device lambda with the output index and the algorithm
* output value as arguments.
*
* \author Rory
* \date 7/9/2017
*/
template <typename func_t>
class DiscardLambdaItr {
public:
// Required iterator traits
typedef DiscardLambdaItr self_type; ///< My own type
typedef ptrdiff_t
difference_type; ///< Type to express the result of subtracting
/// one iterator from another
typedef void
value_type; ///< The type of the element the iterator can point to
typedef value_type *pointer; ///< The type of a pointer to an element the
/// iterator can point to
typedef LauncherItr<func_t> reference; ///< The type of a reference to an element the
/// iterator can point to
typedef typename thrust::detail::iterator_facade_category<
thrust::any_system_tag, thrust::random_access_traversal_tag, value_type,
reference>::type iterator_category; ///< The iterator category
private:
difference_type offset;
func_t f;
public:
XGBOOST_DEVICE DiscardLambdaItr(func_t f) : offset(0), f(f) {}
XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, func_t f)
: offset(offset), f(f) {}
XGBOOST_DEVICE self_type operator+(const int &b) const {
return DiscardLambdaItr(offset + b, f);
}
XGBOOST_DEVICE self_type operator++() {
offset++;
return *this;
}
XGBOOST_DEVICE self_type operator++(int) {
self_type retval = *this;
offset++;
return retval;
}
XGBOOST_DEVICE self_type &operator+=(const int &b) {
offset += b;
return *this;
}
XGBOOST_DEVICE reference operator*() const {
return LauncherItr<func_t>(offset, f);
}
XGBOOST_DEVICE reference operator[](int idx) {
self_type offset = (*this) + idx;
return *offset;
}
};

} // namespace dh
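
A minimal usage sketch of the new iterator (not part of this commit): it shows how dh::DiscardLambdaItr can feed the output of cub::DeviceScan::ExclusiveSum into a device lambda instead of a buffer. The function name ExampleScanIntoLambda is hypothetical; the snippet assumes device_helpers.cuh is included for dh::DiscardLambdaItr and that the file is compiled with --expt-extended-lambda.

#include <cub/cub.cuh>
#include <thrust/device_vector.h>

// Hypothetical helper, for illustration only: writes the exclusive scan of
// `flags` into `scatter` through a device lambda rather than a raw pointer.
void ExampleScanIntoLambda(const thrust::device_vector<int>& flags,
                           thrust::device_vector<int>* scatter) {
  int* d_out = scatter->data().get();
  // Invoked once per output element with (output index, exclusive-scan value).
  auto write = [=] __device__(size_t idx, int scan_value) {
    d_out[idx] = scan_value;
  };
  dh::DiscardLambdaItr<decltype(write)> out_itr(write);

  size_t temp_bytes = 0;
  cub::DeviceScan::ExclusiveSum(nullptr, temp_bytes, flags.data().get(),
                                out_itr, static_cast<int>(flags.size()));
  thrust::device_vector<char> temp_storage(temp_bytes);
  cub::DeviceScan::ExclusiveSum(temp_storage.data().get(), temp_bytes,
                                flags.data().get(), out_itr,
                                static_cast<int>(flags.size()));
}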
130 changes: 76 additions & 54 deletions src/tree/updater_gpu_hist.cu
@@ -380,6 +380,53 @@ struct Segment {
size_t Size() const { return end - begin; }
};

/** \brief Returns one if the left node index is encountered, otherwise
* returns zero. */
struct IndicateLeftTransform {
int left_nidx;
explicit IndicateLeftTransform(int left_nidx) : left_nidx(left_nidx) {}
__host__ __device__ __forceinline__ int operator()(const int& x) const {
return x == left_nidx ? 1 : 0;
}
};

/**
* \brief Optimised routine for sorting key-value pairs into left and right
* segments. Based on a single pass of exclusive scan; uses iterators to
* redirect inputs and outputs.
*/
void SortPosition(dh::CubMemory* temp_memory, common::Span<int> position,
common::Span<int> position_out, common::Span<bst_uint> ridx,
common::Span<bst_uint> ridx_out, int left_nidx,
int right_nidx, int64_t left_count) {
auto d_position_out = position_out.data();
auto d_position_in = position.data();
auto d_ridx_out = ridx_out.data();
auto d_ridx_in = ridx.data();
auto write_results = [=] __device__(size_t idx, int ex_scan_result) {
int scatter_address;
if (d_position_in[idx] == left_nidx) {
scatter_address = ex_scan_result;
} else {
scatter_address = (idx - ex_scan_result) + left_count;
}
d_position_out[scatter_address] = d_position_in[idx];
d_ridx_out[scatter_address] = d_ridx_in[idx];
}; // NOLINT

IndicateLeftTransform conversion_op(left_nidx);
cub::TransformInputIterator<int, IndicateLeftTransform, int*> in_itr(
d_position_in, conversion_op);
dh::DiscardLambdaItr<decltype(write_results)> out_itr(write_results);
size_t temp_storage_bytes = 0;
cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr,
position.size());
temp_memory->LazyAllocate(temp_storage_bytes);
cub::DeviceScan::ExclusiveSum(temp_memory->d_temp_storage,
temp_memory->temp_storage_bytes, in_itr,
out_itr, position.size());
}
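
For reference, a host-side toy walk-through of the scatter-address rule above (not part of the diff, values made up): the exclusive scan of the is-left indicator gives each left row its slot at the front of the segment, while idx - scan + left_count packs the right rows behind them, keeping both sides in their original relative order (a stable partition).

#include <cassert>
#include <vector>

// Toy host illustration of the scatter addressing used by SortPosition.
void ToyPartition() {
  const int left_nidx = 1;
  const int left_count = 3;                      // rows assigned to node 1
  std::vector<int> position = {1, 2, 1, 2, 1};   // node id per row
  std::vector<unsigned> ridx = {0, 1, 2, 3, 4};  // row indices

  std::vector<int> position_out(position.size());
  std::vector<unsigned> ridx_out(ridx.size());
  int ex_scan = 0;  // running exclusive sum of (position[idx] == left_nidx)
  for (size_t idx = 0; idx < position.size(); ++idx) {
    int scatter_address = position[idx] == left_nidx
                              ? ex_scan
                              : static_cast<int>(idx) - ex_scan + left_count;
    position_out[scatter_address] = position[idx];
    ridx_out[scatter_address] = ridx[idx];
    ex_scan += position[idx] == left_nidx ? 1 : 0;
  }
  // Left rows first, right rows after, relative order preserved:
  // position_out == {1, 1, 1, 2, 2}, ridx_out == {0, 2, 4, 1, 3}.
  assert(ridx_out[2] == 4 && ridx_out[3] == 1);
}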

struct DeviceShard;

struct GPUHistBuilderBase {
@@ -440,26 +487,22 @@ struct DeviceShard {
TrainParam param;
bool prediction_cache_initialised;

int64_t* tmp_pinned; // Small amount of staging memory

dh::CubMemory temp_memory;

std::unique_ptr<GPUHistBuilderBase> hist_builder;

// TODO(canonizer): add support for multi-batch DMatrix here
DeviceShard(int device_id,
bst_uint row_begin, bst_uint row_end, TrainParam _param) :
device_id_(device_id),
row_begin_idx(row_begin),
row_end_idx(row_end),
row_stride(0),
n_rows(row_end - row_begin),
n_bins(0),
null_gidx_value(0),
param(_param),
prediction_cache_initialised(false),
tmp_pinned(nullptr)
{}
DeviceShard(int device_id, bst_uint row_begin, bst_uint row_end,
TrainParam _param)
: device_id_(device_id),
row_begin_idx(row_begin),
row_end_idx(row_end),
row_stride(0),
n_rows(row_end - row_begin),
n_bins(0),
null_gidx_value(0),
param(_param),
prediction_cache_initialised(false) {}

/* Init row_ptrs and row_stride */
void InitRowPtrs(const SparsePage& row_batch) {
@@ -495,7 +538,6 @@ struct DeviceShard {
void CreateHistIndices(const SparsePage& row_batch);

~DeviceShard() {
dh::safe_cuda(cudaFreeHost(tmp_pinned));
}

// Reset values for each update iteration
@@ -587,29 +629,18 @@ struct DeviceShard {
hist.HistogramExists(nidx_parent);
}

/*! \brief Count how many rows are assigned to the left node. */
__device__ void CountLeft(int64_t* d_count, int val, int left_nidx) {
unsigned ballot = __ballot(val == left_nidx);
if (threadIdx.x % 32 == 0) {
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
}
}

void UpdatePosition(int nidx, int left_nidx, int right_nidx, int fidx,
int64_t split_gidx, bool default_dir_left, bool is_dense,
int fidx_begin, // cut.row_ptr[fidx]
int fidx_end) { // cut.row_ptr[fidx + 1]
dh::safe_cuda(cudaSetDevice(device_id_));
auto d_left_count = temp_memory.GetSpan<int64_t>(1);
dh::safe_cuda(cudaMemset(d_left_count.data(), 0, sizeof(int64_t)));
Segment segment = ridx_segments[nidx];
bst_uint* d_ridx = ridx.Current();
int* d_position = position.Current();
common::CompressedIterator<uint32_t> d_gidx = gidx;
size_t row_stride = this->row_stride;
// Launch 1 thread for each row
dh::LaunchN<1, 512>(
dh::LaunchN<1, 128>(
device_id_, segment.Size(), [=] __device__(bst_uint idx) {
idx += segment.begin;
bst_uint ridx = d_ridx[idx];
@@ -634,13 +665,16 @@
position = default_dir_left ? left_nidx : right_nidx;
}

CountLeft(d_left_count.data(), position, left_nidx);
d_position[idx] = position;
});
dh::safe_cuda(cudaMemcpy(tmp_pinned, d_left_count.data(), sizeof(int64_t),
cudaMemcpyDeviceToHost));
auto left_count = *tmp_pinned;
SortPosition(segment, left_nidx, right_nidx);
IndicateLeftTransform conversion_op(left_nidx);
cub::TransformInputIterator<int, IndicateLeftTransform, int*> left_itr(
d_position + segment.begin, conversion_op);
int left_count = dh::SumReduction(temp_memory, left_itr, segment.Size());
CHECK_LE(left_count, segment.Size());
CHECK_GE(left_count, 0);

SortPositionAndCopy(segment, left_nidx, right_nidx, left_count);

ridx_segments[left_nidx] =
Segment(segment.begin, segment.begin + left_count);
@@ -649,25 +683,15 @@
}
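
The warp-ballot CountLeft kernel is gone; the left count now comes from a transform-iterator reduction over the updated positions. Below is a standalone sketch of that counting pattern (not part of the commit): the function name CountLeftRows is hypothetical, it reuses the IndicateLeftTransform functor defined earlier in this file, and it calls cub::DeviceReduce directly rather than dh::SumReduction so the snippet is self-contained.

#include <cub/cub.cuh>
#include <thrust/device_vector.h>

// Hypothetical standalone version of the left-row count: map each position to
// 0/1 with IndicateLeftTransform, then sum the indicators on the device.
int CountLeftRows(const thrust::device_vector<int>& position, int left_nidx) {
  IndicateLeftTransform is_left(left_nidx);
  cub::TransformInputIterator<int, IndicateLeftTransform, const int*> in_itr(
      position.data().get(), is_left);

  thrust::device_vector<int> d_count(1);
  size_t temp_bytes = 0;
  cub::DeviceReduce::Sum(nullptr, temp_bytes, in_itr, d_count.data().get(),
                         static_cast<int>(position.size()));
  thrust::device_vector<char> temp_storage(temp_bytes);
  cub::DeviceReduce::Sum(temp_storage.data().get(), temp_bytes, in_itr,
                         d_count.data().get(),
                         static_cast<int>(position.size()));
  return d_count[0];  // single device-to-host copy of the count
}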

/*! \brief Sort row indices according to position. */
void SortPosition(const Segment& segment, int left_nidx, int right_nidx) {
int min_bits = 0;
int max_bits = static_cast<int>(
std::ceil(std::log2((std::max)(left_nidx, right_nidx) + 1)));

size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairs(
nullptr, temp_storage_bytes,
position.Current() + segment.begin, position.other() + segment.begin,
ridx.Current() + segment.begin, ridx.other() + segment.begin,
segment.Size(), min_bits, max_bits);

temp_memory.LazyAllocate(temp_storage_bytes);

cub::DeviceRadixSort::SortPairs(
temp_memory.d_temp_storage, temp_memory.temp_storage_bytes,
position.Current() + segment.begin, position.other() + segment.begin,
ridx.Current() + segment.begin, ridx.other() + segment.begin,
segment.Size(), min_bits, max_bits);
void SortPositionAndCopy(const Segment& segment, int left_nidx, int right_nidx,
size_t left_count) {
SortPosition(
&temp_memory,
common::Span<int>(position.Current() + segment.begin, segment.Size()),
common::Span<int>(position.other() + segment.begin, segment.Size()),
common::Span<bst_uint>(ridx.Current() + segment.begin, segment.Size()),
common::Span<bst_uint>(ridx.other() + segment.begin, segment.Size()),
left_nidx, right_nidx, left_count);
// Copy back key
dh::safe_cuda(cudaMemcpy(
position.Current() + segment.begin, position.other() + segment.begin,
@@ -823,8 +847,6 @@ inline void DeviceShard::InitCompressedData(

// Init histogram
hist.Init(device_id_, hmat.row_ptr.back());

dh::safe_cuda(cudaMallocHost(&tmp_pinned, sizeof(int64_t)));
}

inline void DeviceShard::CreateHistIndices(const SparsePage& row_batch) {
41 changes: 39 additions & 2 deletions tests/cpp/tree/test_gpu_hist.cu
@@ -327,8 +327,6 @@ TEST(GpuHist, ApplySplit) {
shard->row_stride = n_cols;
thrust::sequence(shard->ridx.CurrentDVec().tbegin(),
shard->ridx.CurrentDVec().tend());
// Free inside DeviceShard
dh::safe_cuda(cudaMallocHost(&(shard->tmp_pinned), sizeof(int64_t)));
// Initialize GPUHistMaker
hist_maker.param_ = param;
RegTree tree;
Expand Down Expand Up @@ -389,5 +387,44 @@ TEST(GpuHist, ApplySplit) {
ASSERT_EQ(shard->ridx_segments[right_nidx].end, 16);
}

void TestSortPosition(const std::vector<int>& position_in, int left_idx,
int right_idx) {
int left_count = std::count(position_in.begin(), position_in.end(), left_idx);
thrust::device_vector<int> position = position_in;
thrust::device_vector<int> position_out(position.size());

thrust::device_vector<bst_uint> ridx(position.size());
thrust::sequence(ridx.begin(), ridx.end());
thrust::device_vector<bst_uint> ridx_out(ridx.size());
dh::CubMemory tmp;
SortPosition(
&tmp, common::Span<int>(position.data().get(), position.size()),
common::Span<int>(position_out.data().get(), position_out.size()),
common::Span<bst_uint>(ridx.data().get(), ridx.size()),
common::Span<bst_uint>(ridx_out.data().get(), ridx_out.size()), left_idx,
right_idx, left_count);
thrust::host_vector<int> position_result = position_out;
thrust::host_vector<int> ridx_result = ridx_out;

// Check position is sorted
EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end()));
// Check row indices are sorted inside left and right segment
EXPECT_TRUE(
std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count));
EXPECT_TRUE(
std::is_sorted(ridx_result.begin() + left_count, ridx_result.end()));

// Check key value pairs are the same
for (auto i = 0ull; i < ridx_result.size(); i++) {
EXPECT_EQ(position_result[i], position_in[ridx_result[i]]);
}
}

TEST(GpuHist, SortPosition) {
TestSortPosition({1, 2, 1, 2, 1}, 1, 2);
TestSortPosition({1, 1, 1, 1}, 1, 2);
TestSortPosition({2, 2, 2, 2}, 1, 2);
TestSortPosition({1, 2, 1, 2, 3}, 1, 2);
}
} // namespace tree
} // namespace xgboost
