Skip to content

Commit

Permalink
Avoid thrust logical operation. (#9199)
Browse files Browse the repository at this point in the history
Thrust implementation of `thrust::all_of/any_of/none_of` adopts an early stopping strategy
to bailout early by dividing the input into small batches. This is not ideal for data
validation as we expect all data to be valid. The strategy leads to excessive kernel
launches and stream synchronization.

* Use reduce from dh instead.
  • Loading branch information
trivialfis authored May 26, 2023
1 parent 614f47c commit 053aaba
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
25 changes: 18 additions & 7 deletions src/data/device_adapter.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
: columns_(columns),
num_rows_(num_rows) {}
size_t Size() const { return num_rows_ * columns_.size(); }
__device__ COOTuple GetElement(size_t idx) const {
__device__ __forceinline__ COOTuple GetElement(size_t idx) const {
size_t column_idx = idx % columns_.size();
size_t row_idx = idx / columns_.size();
auto const& column = columns_[column_idx];
Expand Down Expand Up @@ -221,13 +221,24 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
* \brief Check there's no inf in data.
*/
template <typename AdapterBatchT>
bool HasInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) {
bool NoInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) {
auto counting = thrust::make_counting_iterator(0llu);
auto value_iter = dh::MakeTransformIterator<float>(
counting, [=] XGBOOST_DEVICE(std::size_t idx) { return batch.GetElement(idx).value; });
auto valid =
thrust::none_of(value_iter, value_iter + batch.Size(),
[is_valid] XGBOOST_DEVICE(float v) { return is_valid(v) && std::isinf(v); });
auto value_iter = dh::MakeTransformIterator<bool>(counting, [=] XGBOOST_DEVICE(std::size_t idx) {
auto v = batch.GetElement(idx).value;
if (!is_valid(v)) {
// discard the invalid elements.
return true;
}
// check that there's no inf in data.
return !std::isinf(v);
});
dh::XGBCachingDeviceAllocator<char> alloc;
// The default implementation in thrust optimizes any_of/none_of/all_of by using small
// intervals to early stop. But we expect all data to be valid here, using small
// intervals only decreases performance due to excessive kernel launch and stream
// synchronization.
auto valid = dh::Reduce(thrust::cuda::par(alloc), value_iter, value_iter + batch.Size(), true,
thrust::logical_and<>{});
return valid;
}
}; // namespace data
Expand Down
2 changes: 1 addition & 1 deletion src/data/ellpack_page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType cons
// correct output position
auto counting = thrust::make_counting_iterator(0llu);
data::IsValidFunctor is_valid(missing);
bool valid = data::HasInfInData(batch, is_valid);
bool valid = data::NoInfInData(batch, is_valid);
CHECK(valid) << error::InfInData();

auto key_iter = dh::MakeTransformIterator<size_t>(
Expand Down
2 changes: 1 addition & 1 deletion src/data/simple_dmatrix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
template <typename AdapterBatchT>
size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing,
SparsePage* page) {
bool valid = HasInfInData(batch, IsValidFunctor{missing});
bool valid = NoInfInData(batch, IsValidFunctor{missing});
CHECK(valid) << error::InfInData();

page->offset.SetDevice(device);
Expand Down

0 comments on commit 053aaba

Please sign in to comment.