Skip to content

Commit

Permalink
Use HalfFloatType and CTypeOrFloat16
Browse files Browse the repository at this point in the history
  • Loading branch information
AlenkaF committed Mar 24, 2024
1 parent a43dafe commit ff56c20
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions cpp/src/arrow/record_batch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,10 @@ Result<std::shared_ptr<StructArray>> RecordBatch::ToStructArray() const {
/*offset=*/0);
}

template <typename Num>
using CTypeOrFloat16 = std::conditional_t<Num::type_id == Type::HALF_FLOAT,
uint16_t, typename Num::c_type>;

template <typename Out>
struct ConvertColumnsToTensorVisitor {
Out*& out_values;
Expand All @@ -258,7 +262,7 @@ struct ConvertColumnsToTensorVisitor {
template <typename T>
Status Visit(const T&) {
if constexpr (is_numeric(T::type_id)) {
using In = typename T::c_type;
using In = CTypeOrFloat16<T>;
auto in_values = ArraySpan(in_data).GetSpan<In>(1, in_data.length);

if constexpr (std::is_same_v<In, Out>) {
Expand All @@ -277,7 +281,7 @@ struct ConvertColumnsToTensorVisitor {

template <typename DataType>
inline void ConvertColumnsToTensor(const RecordBatch& batch, uint8_t* out) {
using CType = typename arrow::TypeTraits<DataType>::CType;
using CType = CTypeOrFloat16<DataType>;
auto* out_values = reinterpret_cast<CType*>(out);

for (const auto& column : batch.columns()) {
Expand Down Expand Up @@ -336,7 +340,6 @@ Result<std::shared_ptr<Tensor>> RecordBatch::ToTensor(MemoryPool* pool) const {
ConvertColumnsToTensor<UInt8Type>(*this, result->mutable_data());
break;
case Type::UINT16:
case Type::HALF_FLOAT:
ConvertColumnsToTensor<UInt16Type>(*this, result->mutable_data());
break;
case Type::UINT32:
Expand All @@ -357,6 +360,9 @@ Result<std::shared_ptr<Tensor>> RecordBatch::ToTensor(MemoryPool* pool) const {
case Type::INT64:
ConvertColumnsToTensor<Int64Type>(*this, result->mutable_data());
break;
case Type::HALF_FLOAT:
ConvertColumnsToTensor<HalfFloatType>(*this, result->mutable_data());
break;
case Type::FLOAT:
ConvertColumnsToTensor<FloatType>(*this, result->mutable_data());
break;
Expand Down

0 comments on commit ff56c20

Please sign in to comment.