Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-186] backport to 1.1 branch #268

Merged
merged 2 commits into from
Apr 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2559,6 +2559,7 @@ class AvgByCountAction<DataType, CType, ResDataType, ResCType,
auto typed_type = std::dynamic_pointer_cast<arrow::Decimal128Type>(type);
auto typed_res_type = std::dynamic_pointer_cast<arrow::Decimal128Type>(res_type);
scale_ = typed_type->scale();
res_precision_ = typed_type->precision();
res_scale_ = typed_res_type->scale();
std::unique_ptr<arrow::ArrayBuilder> builder;
arrow::MakeBuilder(ctx_->memory_pool(), res_type, &builder);
Expand Down Expand Up @@ -2660,10 +2661,11 @@ class AvgByCountAction<DataType, CType, ResDataType, ResCType,
cache_sum_[i] = 0;
} else {
cache_validity_[i] = true;
if (res_scale_ > scale_) {
if (res_scale_ != scale_) {
cache_sum_[i] = cache_sum_[i].Rescale(scale_, res_scale_).ValueOrDie();
}
cache_sum_[i] /= cache_count_[i];
cache_sum_[i] =
divide(cache_sum_[i], res_precision_, res_scale_, cache_count_[i]);
}
}
cache_sum_.resize(length_);
Expand Down Expand Up @@ -2691,11 +2693,12 @@ class AvgByCountAction<DataType, CType, ResDataType, ResCType,
cache_sum_[i + offset] = 0;
} else {
cache_validity_[i + offset] = true;
if (res_scale_ > scale_) {
if (res_scale_ != scale_) {
cache_sum_[i + offset] =
cache_sum_[i + offset].Rescale(scale_, res_scale_).ValueOrDie();
}
cache_sum_[i + offset] /= cache_count_[i + offset];
cache_sum_[i + offset] = divide(cache_sum_[i + offset], res_precision_,
res_scale_, cache_count_[i + offset]);
}
}
for (uint64_t i = 0; i < res_length; i++) {
Expand Down Expand Up @@ -2724,6 +2727,7 @@ class AvgByCountAction<DataType, CType, ResDataType, ResCType,
int in_null_count_ = 0;
// result
int scale_;
int res_precision_;
int res_scale_;
std::vector<ResCType> cache_sum_;
std::vector<int64_t> cache_count_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -757,8 +757,7 @@ class HashAggregateKernel::Impl {
int gp_idx = 0;
std::vector<std::shared_ptr<arrow::Array>> outputs;
for (auto action : action_impl_list_) {
// FIXME(): to work around NSE-241
action->Finish(offset_, 20000, &outputs);
action->Finish(offset_, batch_size_, &outputs);
}
if (outputs.size() > 0) {
out_length += outputs[0]->length();
Expand Down Expand Up @@ -917,8 +916,7 @@ class HashAggregateKernel::Impl {
int gp_idx = 0;
std::vector<std::shared_ptr<arrow::Array>> outputs;
for (auto action : action_impl_list_) {
// FIXME(): to work around NSE-241
action->Finish(offset_, 20000, &outputs);
action->Finish(offset_, batch_size_, &outputs);
}
if (outputs.size() > 0) {
out_length += outputs[0]->length();
Expand Down Expand Up @@ -1074,8 +1072,7 @@ class HashAggregateKernel::Impl {
int gp_idx = 0;
std::vector<std::shared_ptr<arrow::Array>> outputs;
for (auto action : action_impl_list_) {
// FIXME(): to work around NSE-241
action->Finish(offset_, 20000, &outputs);
action->Finish(offset_, batch_size_, &outputs);
}
if (outputs.size() > 0) {
out_length += outputs[0]->length();
Expand Down
7 changes: 7 additions & 0 deletions native-sql-engine/cpp/src/precompile/gandiva.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,13 @@ arrow::Decimal128 divide(arrow::Decimal128 left, int32_t left_precision,
return arrow::Decimal128(out);
}

arrow::Decimal128 divide(const arrow::Decimal128& x, int32_t precision, int32_t scale,
int64_t y) {
gandiva::BasicDecimalScalar128 val(x, precision, scale);
arrow::BasicDecimal128 out = gandiva::decimalops::Divide(val, y);
return arrow::Decimal128(out);
}

// A comparison with a NaN always returns false even when comparing with itself.
// To get the same result as spark, we can regard NaN as big as Infinity when
// doing comparison.
Expand Down
11 changes: 10 additions & 1 deletion native-sql-engine/cpp/src/shuffle/splitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,11 @@ arrow::Status Splitter::CacheRecordBatch(int32_t partition_id, bool reset_buffer
auto& builder = partition_binary_builders_[binary_idx][partition_id];
if (reset_buffers) {
RETURN_NOT_OK(builder->Finish(&arrays[i]));
builder->Reset();
} else {
auto data_size = builder->value_data_length();
RETURN_NOT_OK(builder->Finish(&arrays[i]));
builder->Reset();
RETURN_NOT_OK(builder->Reserve(num_rows));
RETURN_NOT_OK(builder->ReserveData(data_size));
}
Expand All @@ -441,9 +443,11 @@ arrow::Status Splitter::CacheRecordBatch(int32_t partition_id, bool reset_buffer
partition_large_binary_builders_[large_binary_idx][partition_id];
if (reset_buffers) {
RETURN_NOT_OK(builder->Finish(&arrays[i]));
builder->Reset();
} else {
auto data_size = builder->value_data_length();
RETURN_NOT_OK(builder->Finish(&arrays[i]));
builder->Reset();
RETURN_NOT_OK(builder->Reserve(num_rows));
RETURN_NOT_OK(builder->ReserveData(data_size));
}
Expand Down Expand Up @@ -699,6 +703,9 @@ arrow::Status Splitter::DoSplit(const arrow::RecordBatch& rb) {
RETURN_NOT_OK(AllocatePartitionBuffers(pid, new_size));
} else { // not first allocate, spill
if (partition_id_cnt_[pid] > partition_buffer_size_[pid]) { // need reallocate?
// TODO(): CacheRecordBatch will try to reset builder buffer
// AllocatePartitionBuffers will then Reserve memory for builder based on last
// recordbatch, the logic on reservation size should be cleaned up
RETURN_NOT_OK(CacheRecordBatch(pid, true));
RETURN_NOT_OK(SpillPartition(pid));
RETURN_NOT_OK(AllocatePartitionBuffers(pid, new_size));
Expand Down Expand Up @@ -1047,6 +1054,7 @@ arrow::Status Splitter::AppendBinary(
offset_type length;
auto value = src_arr->GetValue(row, &length);
const auto& builder = dst_builders[partition_id_[row]];
RETURN_NOT_OK(builder->Reserve(1));
RETURN_NOT_OK(builder->ReserveData(length));
builder->UnsafeAppend(value, length);
}
Expand All @@ -1056,10 +1064,11 @@ arrow::Status Splitter::AppendBinary(
offset_type length;
auto value = src_arr->GetValue(row, &length);
const auto& builder = dst_builders[partition_id_[row]];
RETURN_NOT_OK(builder->Reserve(1));
RETURN_NOT_OK(builder->ReserveData(length));
builder->UnsafeAppend(value, length);
} else {
dst_builders[partition_id_[row]]->UnsafeAppendNull();
dst_builders[partition_id_[row]]->AppendNull();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ TEST(TestArrowCompute, AggregateTest) {
"[39]",
R"(["345.262397"])",
"[785]",
R"(["0.439824"])",
R"(["0.439825"])",
R"([39])",
R"([8.85288])",
R"([11113.3])"};
Expand Down Expand Up @@ -284,8 +284,8 @@ TEST(TestArrowCompute, GroupByAggregateTest) {
R"(["15.704202", "12.050089", "19.776600", "15.878089", "24.840018",
null, "28.101000", "22.136100", "16.008800", "26.676800", "164.090699"])",
R"([140, 20, 11, 89, 131, null, 57, 27, 10, 89, 211])",
R"(["0.1121728714", "0.6025044500", "1.7978727272", "0.1784054943", "0.1896184580",
null, "0.4930000000", "0.8198555555", "1.6008800000", "0.2997393258", "0.7776810379"])"};
R"(["0.1121728714", "0.6025044500", "1.7978727273", "0.1784054944", "0.1896184580",
null, "0.4930000000", "0.8198555556", "1.6008800000", "0.2997393258", "0.7776810379"])"};
auto res_sch = arrow::schema(ret_types);
MakeInputBatch(expected_result_string, res_sch, &expected_result);
if (aggr_result_iterator->HasNext()) {
Expand Down Expand Up @@ -425,8 +425,8 @@ TEST(TestArrowCompute, GroupByAggregateWSCGTest) {
R"(["15.704202", "12.050089", "19.776600", "15.878089", "24.840018",
null, "28.101000", "22.136100", "16.008800", "26.676800", "164.090699"])",
R"([140, 20, 11, 89, 131, null, 57, 27, 10, 89, 211])",
R"(["0.1121728714", "0.6025044500", "1.7978727272", "0.1784054943", "0.1896184580",
null, "0.4930000000", "0.8198555555", "1.6008800000", "0.2997393258", "0.7776810379"])"};
R"(["0.1121728714", "0.6025044500", "1.7978727273", "0.1784054944", "0.1896184580",
null, "0.4930000000", "0.8198555556", "1.6008800000", "0.2997393258", "0.7776810379"])"};
auto res_sch = arrow::schema(ret_types);
MakeInputBatch(expected_result_string, res_sch, &expected_result);
if (aggr_result_iterator->HasNext()) {
Expand Down
11 changes: 11 additions & 0 deletions native-sql-engine/cpp/src/tests/arrow_compute_test_precompile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,17 @@ TEST(TestArrowCompute, ArithmeticDecimalTest) {
ASSERT_EQ(res, arrow::Decimal128("32342423.0129"));
res = arrow::Decimal128("-32342423.012875").Abs();
ASSERT_EQ(res, left);
// decimal divide int test
auto x = arrow::Decimal128("30.222215");
int32_t x_precision = 14;
int32_t x_scale = 6;
int64_t y = 8;
res = x / y;
// wrong result
ASSERT_EQ(res, arrow::Decimal128("3.777776"));
// correct result
res = divide(x, x_precision, x_scale, y);
ASSERT_EQ(res, arrow::Decimal128("3.777777"));
}

TEST(TestArrowCompute, ArithmeticComparisonTest) {
Expand Down
17 changes: 17 additions & 0 deletions native-sql-engine/cpp/src/third_party/gandiva/decimal_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,23 @@ BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x,
return result;
}

BasicDecimal128 Divide(const BasicDecimalScalar128& x, int64_t y) {
if (y == 0) {
throw std::runtime_error("divide by zero error");
}
BasicDecimal128 result;
BasicDecimal128 remainder;
auto status = x.value().Divide(y, &result, &remainder);
DCHECK_EQ(status, arrow::DecimalStatus::kSuccess);
// round-up
// returns 1 for positive and zero values, -1 for negative values.
int64_t y_sign = y < 0 ? -1 : 1;
if (BasicDecimal128::Abs(2 * remainder) >= BasicDecimal128::Abs(y)) {
result += (x.value().Sign() ^ y_sign) + 1;
}
return result;
}

BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x,
const BasicDecimalScalar128& y, int32_t out_precision,
int32_t out_scale, bool* overflow) {
Expand Down
3 changes: 3 additions & 0 deletions native-sql-engine/cpp/src/third_party/gandiva/decimal_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ arrow::BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x,
const BasicDecimalScalar128& y, int32_t out_precision,
int32_t out_scale, bool* overflow);

// Divide 'x'(decimal) by 'y'(int64_t), and return the result.
BasicDecimal128 Divide(const BasicDecimalScalar128& x, int64_t y);

/// Divide 'x' by 'y', and return the remainder.
arrow::BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x,
const BasicDecimalScalar128& y, int32_t out_precision,
Expand Down