Skip to content

Commit

Permalink
Use visitor pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
AlenkaF committed Mar 21, 2024
1 parent 632f401 commit b5cf637
Showing 1 changed file with 28 additions and 42 deletions.
70 changes: 28 additions & 42 deletions cpp/src/arrow/record_batch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
#include "arrow/type.h"
#include "arrow/util/iterator.h"
#include "arrow/util/logging.h"
#include "arrow/util/unreachable.h"
#include "arrow/util/vector.h"
#include "arrow/visit_scalar_inline.h"

namespace arrow {

Expand Down Expand Up @@ -248,58 +250,42 @@ Result<std::shared_ptr<StructArray>> RecordBatch::ToStructArray() const {
/*offset=*/0);
}

#define TYPE_CASE(type) \
case Type::type: { \
using T = typename TypeIdTraits<Type::type>::Type; \
using CType_in = typename TypeTraits<T>::CType; \
auto* in_values = batch.column(i)->data()->GetValues<CType_in>(1); \
for (int64_t i = 0; i < length; ++i) { \
*out_values++ = static_cast<CType>(*in_values++); \
} \
break; \
/// \brief Type visitor that appends one column's values to a typed output cursor.
///
/// `Visit` is selected by the column's Arrow data type `T`. For numeric types
/// it reads `in_data.length` values of `T::c_type` from buffer 1 of `in_data`,
/// writes them at `out_values` converted to `Out`, and advances `out_values`
/// past what was written. Any other type ends up in `Unreachable()`.
///
/// The validity bitmap is not consulted here: every slot of the value buffer
/// is copied or converted as-is.
///
/// NOTE(review): for half-float columns `T::c_type` is presumably the raw
/// uint16_t storage, so the cast would convert the bit pattern rather than the
/// half-precision value -- confirm whether callers rule this case out.
template <typename Out>
struct ConvertColumnsToTensorVisitor {
  // Write position, held by reference so the caller's pointer moves forward
  // as values are appended.
  Out*& out_values;
  // The column whose values are read.
  const ArrayData& in_data;

  template <typename T>
  Status Visit(const T&) {
    if constexpr (!is_numeric(T::type_id)) {
      // Only numeric columns are expected to reach this visitor.
      Unreachable();
    } else {
      using Source = typename T::c_type;
      auto column_values = ArraySpan(in_data).GetSpan<Source>(1, in_data.length);

      if constexpr (std::is_same_v<Source, Out>) {
        // Identical C types: one bulk copy, then step over the copied values.
        memcpy(out_values, column_values.data(), column_values.size_bytes());
        out_values += column_values.size();
      } else {
        // Different C types: convert element by element.
        const Source* cursor = column_values.data();
        const Source* const stop = cursor + column_values.size();
        for (; cursor != stop; ++cursor) {
          *out_values++ = static_cast<Out>(*cursor);
        }
      }
      return Status::OK();
    }
  }
};

// Writes the columns of `batch` into `out`, one column after another, with
// every value converted to the C type of `DataType`.
//
// `out` is reinterpreted as a `CType*` and advanced as values are written.
// Presumably it must have room for num_rows * num_columns values of CType --
// TODO confirm against the caller.
//
// NOTE(review): as listed here the body contains two column loops (an
// index-based memcpy/switch loop and a visitor-based loop) and the braces do
// not balance. This looks like the removed and added sides of a diff shown
// together; confirm against the repository which loop is the current one.
template <typename DataType>
inline void ConvertColumnsToTensor(const RecordBatch& batch, uint8_t* out) {
using CType = typename arrow::TypeTraits<DataType>::CType;
auto* out_values = reinterpret_cast<CType*>(out);
// Number of values written per column by the index-based loop below.
int64_t length = batch.num_rows();

for (int i = 0; i < batch.num_columns(); ++i) {
// The column has the same type as the resulting data type: bulk-copy it.
if (TypeTraits<DataType>::type_singleton() == batch.column(i)->type()) {
// Buffer 1 of the column's data, read as CType.
const auto* in_values = batch.column(i)->data()->GetValues<CType>(1);

memcpy(out_values, in_values, sizeof(CType) * length);
out_values += length;
} else { // The column has a different type than the resulting data type
switch (batch.column(i)->type_id()) {
// Half-float values are read as uint16_t and cast from that.
case Type::HALF_FLOAT: {
auto* in_values = batch.column(i)->data()->GetValues<uint16_t>(1);
for (int64_t i = 0; i < length; ++i) {
*out_values++ = static_cast<CType>(*in_values++);
}
break;
}
// One element-wise conversion case per remaining numeric type id.
TYPE_CASE(UINT8)
TYPE_CASE(UINT16)
TYPE_CASE(UINT32)
TYPE_CASE(UINT64)
TYPE_CASE(INT8)
TYPE_CASE(INT16)
TYPE_CASE(INT32)
TYPE_CASE(INT64)
TYPE_CASE(FLOAT)
TYPE_CASE(DOUBLE)
// Any other type id writes nothing and leaves out_values where it was.
default:
break;
}
}
// Visitor-based form: dispatch on each column's type; the visitor advances
// out_values through the reference it holds.
for (const auto& column : batch.columns()) {
ConvertColumnsToTensorVisitor<CType> visitor{out_values, *column->data()};
// NOTE(review): the returned Status is only passed to DCHECK_OK, presumably
// a debug-build check -- confirm.
DCHECK_OK(VisitTypeInline(*column->type(), &visitor));
}
}

#undef TYPE_CASE

Result<std::shared_ptr<Tensor>> RecordBatch::ToTensor(MemoryPool* pool) const {
if (num_columns() == 0) {
return Status::TypeError(
Expand Down

0 comments on commit b5cf637

Please sign in to comment.