Optimize array interface input. #9090

Merged · 4 commits · Apr 28, 2023
4 changes: 4 additions & 0 deletions src/common/error_msg.h
@@ -24,5 +24,9 @@ constexpr StringView LabelScoreSize() {
 constexpr StringView InfInData() {
   return "Input data contains `inf` or a value too large, while `missing` is not set to `inf`";
 }
+
+constexpr StringView NoF128() {
+  return "128-bit floating point is not supported on current platform.";
+}
 } // namespace xgboost::error
 #endif // XGBOOST_COMMON_ERROR_MSG_H_
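
For context, here is a minimal sketch of how the new helper is meant to be consumed (a hypothetical call site, not part of this diff; it assumes dmlc's `CHECK` macro is in scope, as it is throughout XGBoost):

```cpp
#include "../common/error_msg.h"  // for NoF128

// Centralizing the message in error::NoF128() keeps the wording identical at
// every call site. Note that `long double` is not 16 bytes on every platform
// (it is 8 bytes under MSVC, for example), hence the runtime guard.
void RequireF128Support() {
  CHECK(sizeof(long double) == 16) << xgboost::error::NoF128();
}
```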
84 changes: 78 additions & 6 deletions src/data/array_interface.h
@@ -7,8 +7,9 @@
 #define XGBOOST_DATA_ARRAY_INTERFACE_H_
 
 #include <algorithm>
-#include <cstddef>  // std::size_t
+#include <cstddef>  // for size_t
 #include <cstdint>
+#include <limits>  // for numeric_limits
 #include <map>
 #include <string>
 #include <type_traits>  // std::alignment_of,std::remove_pointer_t
@@ -17,6 +18,7 @@

 #include "../common/bitfield.h"
 #include "../common/common.h"
+#include "../common/error_msg.h"  // for NoF128
 #include "xgboost/base.h"
 #include "xgboost/data.h"
 #include "xgboost/json.h"
@@ -454,9 +456,8 @@ class ArrayInterface {
   void AssignType(StringView typestr) {
     using T = ArrayInterfaceHandler::Type;
     if (typestr.size() == 4 && typestr[1] == 'f' && typestr[2] == '1' && typestr[3] == '6') {
+      CHECK(sizeof(long double) == 16) << error::NoF128();
       type = T::kF16;
-      CHECK(sizeof(long double) == 16)
-          << "128-bit floating point is not supported on current platform.";
     } else if (typestr[1] == 'f' && typestr[2] == '2') {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
       type = T::kF2;
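
A side note on the `typestr` checks above: the string follows the NumPy array-interface convention of byte order, type kind, and item size, so a little-endian 128-bit float arrives as `"<f16"`. A self-contained illustration (hypothetical helper, not from the PR):

```cpp
#include <string_view>

// "<f16": '<' = little-endian, 'f' = floating point, "16" = bytes per
// element. This mirrors the condition tested in AssignType above.
bool IsF128(std::string_view typestr) {
  return typestr.size() == 4 && typestr[1] == 'f' && typestr[2] == '1' &&
         typestr[3] == '6';
}
```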
@@ -572,19 +573,90 @@ class ArrayInterface {
   // Used only by columnar format.
   RBitField8 valid;
   // Array stride
-  size_t strides[D]{0};
+  std::size_t strides[D]{0};
   // Array shape
-  size_t shape[D]{0};
+  std::size_t shape[D]{0};
   // Type erased pointer referencing the data.
   void const *data{nullptr};
   // Total number of items
-  size_t n{0};
+  std::size_t n{0};
   // Whether the memory is c-contiguous
   bool is_contiguous{false};
   // RTTI, initialized to the f16 to avoid masking potential bugs in initialization.
   ArrayInterfaceHandler::Type type{ArrayInterfaceHandler::kF16};
 };

+template <std::int32_t D, typename Fn>
+void DispatchDType(ArrayInterface<D> const array, std::int32_t device, Fn fn) {
+  // Only used for cuDF at the moment.
+  CHECK_EQ(array.valid.Size(), 0);
+  auto dispatch = [&](auto t) {
+    using T = std::remove_const_t<decltype(t)> const;
+    // Set the data size to max as we don't know the original size of a sliced array:
+    //
+    // Slicing an array A with shape (4, 2, 3) and stride (6, 3, 1) by [:, 1, :] results
+    // in an array B with shape (4, 3) and strides (6, 1). We can't calculate the original
+    // size 24 based on the slice.
+    fn(linalg::TensorView<T, D>{common::Span<T const>{static_cast<T *>(array.data),
+                                                      std::numeric_limits<std::size_t>::max()},
+                                array.shape, array.strides, device});
+  };
+  switch (array.type) {
+    case ArrayInterfaceHandler::kF2: {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
+      dispatch(__half{});
+#endif
+      break;
+    }
+    case ArrayInterfaceHandler::kF4: {
+      dispatch(float{});
+      break;
+    }
+    case ArrayInterfaceHandler::kF8: {
+      dispatch(double{});
+      break;
+    }
+    case ArrayInterfaceHandler::kF16: {
+      using T = long double;
+      CHECK(sizeof(long double) == 16) << error::NoF128();
+      dispatch(T{});
+      break;
+    }
+    case ArrayInterfaceHandler::kI1: {
+      dispatch(std::int8_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kI2: {
+      dispatch(std::int16_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kI4: {
+      dispatch(std::int32_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kI8: {
+      dispatch(std::int64_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kU1: {
+      dispatch(std::uint8_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kU2: {
+      dispatch(std::uint16_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kU4: {
+      dispatch(std::uint32_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kU8: {
+      dispatch(std::uint64_t{});
+      break;
+    }
+  }
+}
+
 /**
  * \brief Helper for type casting.
  */
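
To make the new entry point concrete, here is a usage sketch (a hypothetical call site, not part of the diff). It assumes an `ArrayInterface<2>` describing the sliced array `B` from the comment above, that is, dtype `kF8`, shape `(4, 3)`, element strides `(6, 1)`, and sums it through the typed `linalg::TensorView` that `DispatchDType` passes to the callback:

```cpp
// Assumes `array` was parsed from an "__array_interface__" JSON document; the
// non-contiguous strides mean view(i, j) reads data[i * 6 + j * 1], so the
// callback never needs to know that the array was sliced.
double SumF64(ArrayInterface<2> const& array) {
  double sum = 0.0;
  DispatchDType(array, Context::kCpuId, [&](auto&& view) {
    for (std::size_t i = 0; i < view.Shape(0); ++i) {
      for (std::size_t j = 0; j < view.Shape(1); ++j) {
        sum += static_cast<double>(view(i, j));
      }
    }
  });
  return sum;
}
```

Because the lambda is generic, it is instantiated once per dtype in the switch; only the branch matching `array.type` runs at runtime.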
11 changes: 7 additions & 4 deletions src/data/data.cc
@@ -427,10 +427,13 @@ void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T
     return;
   }
   p_out->Reshape(array.shape);
-  auto t = p_out->View(Context::kCpuId);
-  CHECK(t.CContiguous());
-  linalg::ElementWiseTransformHost(t, ctx.Threads(), [&](auto i, auto) {
-    return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex<D>(i, t.Shape()));
+  auto t_out = p_out->View(Context::kCpuId);
+  CHECK(t_out.CContiguous());
+  auto const shape = t_out.Shape();
+  DispatchDType(array, Context::kCpuId, [&](auto&& in) {
+    linalg::ElementWiseTransformHost(t_out, ctx.Threads(), [&](auto i, auto) {
+      return std::apply(in, linalg::UnravelIndex<D>(i, shape));
+    });
   });
 }
 } // namespace
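
Two pieces of the replacement are worth spelling out: `linalg::UnravelIndex<D>` converts a flat element index into a tuple of `D` coordinates, and `std::apply` expands that tuple into a call on the typed input view. A hand-rolled equivalent for the 2-D case (illustrative only; the shape is assumed to be `{2, 3}`):

```cpp
#include <cstddef>  // for size_t
#include <tuple>    // for tuple

// Row-major unravel for shape {2, 3}: flat index i maps to (i / 3, i % 3),
// e.g. flat 4 -> (1, 1). This is what linalg::UnravelIndex<2> computes.
std::tuple<std::size_t, std::size_t> Unravel2D(std::size_t i) {
  return {i / 3, i % 3};
}
// std::apply(in, Unravel2D(4)) then calls in(1, 1) on the dispatched view,
// which is the same pattern the lambda in CopyTensorInfoImpl above uses.
```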