From b5cf637e8d0da49fb666f40459aac50a2907bfde Mon Sep 17 00:00:00 2001 From: AlenkaF Date: Thu, 21 Mar 2024 11:48:09 +0100 Subject: [PATCH] Use visitor pattern --- cpp/src/arrow/record_batch.cc | 70 ++++++++++++++--------------------- 1 file changed, 28 insertions(+), 42 deletions(-) diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index d9d2cd4b6557b..f93c81d416686 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -34,7 +34,9 @@ #include "arrow/type.h" #include "arrow/util/iterator.h" #include "arrow/util/logging.h" +#include "arrow/util/unreachable.h" #include "arrow/util/vector.h" +#include "arrow/visit_scalar_inline.h" namespace arrow { @@ -248,58 +250,42 @@ Result> RecordBatch::ToStructArray() const { /*offset=*/0); } -#define TYPE_CASE(type) \ - case Type::type: { \ - using T = typename TypeIdTraits::Type; \ - using CType_in = typename TypeTraits::CType; \ - auto* in_values = batch.column(i)->data()->GetValues(1); \ - for (int64_t i = 0; i < length; ++i) { \ - *out_values++ = static_cast(*in_values++); \ - } \ - break; \ +template +struct ConvertColumnsToTensorVisitor { + Out*& out_values; + const ArrayData& in_data; + + template + Status Visit(const T&) { + if constexpr (is_numeric(T::type_id)) { + using In = typename T::c_type; + auto in_values = ArraySpan(in_data).GetSpan(1, in_data.length); + + if constexpr (std::is_same_v) { + memcpy(out_values, in_values.data(), in_values.size_bytes()); + out_values += in_values.size(); + } else { + for (In in_value : in_values) { + *out_values++ = static_cast(in_value); + } + } + return Status::OK(); + } + Unreachable(); } +}; template inline void ConvertColumnsToTensor(const RecordBatch& batch, uint8_t* out) { using CType = typename arrow::TypeTraits::CType; auto* out_values = reinterpret_cast(out); - int64_t length = batch.num_rows(); - for (int i = 0; i < batch.num_columns(); ++i) { - // If the column is of the same type than resulting data type - if (TypeTraits::type_singleton() == batch.column(i)->type()) { - const auto* in_values = batch.column(i)->data()->GetValues(1); - - memcpy(out_values, in_values, sizeof(CType) * length); - out_values += length; - } else { // If the column is different type than resulting data type - switch (batch.column(i)->type_id()) { - case Type::HALF_FLOAT: { - auto* in_values = batch.column(i)->data()->GetValues(1); - for (int64_t i = 0; i < length; ++i) { - *out_values++ = static_cast(*in_values++); - } - break; - } - TYPE_CASE(UINT8) - TYPE_CASE(UINT16) - TYPE_CASE(UINT32) - TYPE_CASE(UINT64) - TYPE_CASE(INT8) - TYPE_CASE(INT16) - TYPE_CASE(INT32) - TYPE_CASE(INT64) - TYPE_CASE(FLOAT) - TYPE_CASE(DOUBLE) - default: - break; - } - } + for (const auto& column : batch.columns()) { + ConvertColumnsToTensorVisitor visitor{out_values, *column->data()}; + DCHECK_OK(VisitTypeInline(*column->type(), &visitor)); } } -#undef TYPE_CASE - Result> RecordBatch::ToTensor(MemoryPool* pool) const { if (num_columns() == 0) { return Status::TypeError(