Skip to content

Commit

Permalink
scan optimization for filter applying in case simple chunks (ydb-plat…
Browse files Browse the repository at this point in the history
…form#12476)

Тест флаки
  • Loading branch information
ivanmorozov333 authored and zverevgeny committed Jan 8, 2025
1 parent bdc72dc commit 1ae3179
Show file tree
Hide file tree
Showing 15 changed files with 341 additions and 99 deletions.
172 changes: 125 additions & 47 deletions ydb/core/formats/arrow/arrow_filter.cpp
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
#include "arrow_filter.h"
#include "switch/switch_type.h"
#include "common/container.h"

#include "common/adapter.h"
#include "common/container.h"
#include "switch/switch_type.h"

#include <ydb/library/actors/core/log.h>
#include <ydb/library/yverify_stream/yverify_stream.h>

#include <contrib/libs/apache/arrow/cpp/src/arrow/array/builder_primitive.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/chunked_array.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api_vector.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/record_batch.h>
#include <ydb/library/yverify_stream/yverify_stream.h>
#include <ydb/library/actors/core/log.h>

namespace NKikimr::NArrow {

#define Y_VERIFY_OK(status) Y_ABORT_UNLESS(status.ok(), "%s", status.ToString().c_str())

namespace {
enum class ECompareResult: i8 {
enum class ECompareResult : i8 {
LESS = -1,
BORDER = 0,
GREATER = 1
Expand Down Expand Up @@ -50,8 +52,7 @@ inline void UpdateCompare(const T& value, const T& border, ECompareResult& res)
}

template <typename TArray, typename T>
bool CompareImpl(const std::shared_ptr<arrow::Array>& column, const T& border,
std::vector<NArrow::ECompareResult>& rowsCmp) {
bool CompareImpl(const std::shared_ptr<arrow::Array>& column, const T& border, std::vector<NArrow::ECompareResult>& rowsCmp) {
bool hasBorder = false;
ECompareResult* res = &rowsCmp[0];
auto array = std::static_pointer_cast<TArray>(column);
Expand All @@ -64,8 +65,7 @@ bool CompareImpl(const std::shared_ptr<arrow::Array>& column, const T& border,
}

template <typename TArray, typename T>
bool CompareImpl(const std::shared_ptr<arrow::ChunkedArray>& column, const T& border,
std::vector<NArrow::ECompareResult>& rowsCmp) {
bool CompareImpl(const std::shared_ptr<arrow::ChunkedArray>& column, const T& border, std::vector<NArrow::ECompareResult>& rowsCmp) {
bool hasBorder = false;
ECompareResult* res = &rowsCmp[0];

Expand All @@ -82,8 +82,7 @@ bool CompareImpl(const std::shared_ptr<arrow::ChunkedArray>& column, const T& bo

/// @return true in case we have no borders in compare: no need for future keys, allow early exit
template <typename TArray>
bool Compare(const arrow::Datum& column, const std::shared_ptr<arrow::Array>& borderArray,
std::vector<NArrow::ECompareResult>& rowsCmp) {
bool Compare(const arrow::Datum& column, const std::shared_ptr<arrow::Array>& borderArray, std::vector<NArrow::ECompareResult>& rowsCmp) {
auto border = GetValue(std::static_pointer_cast<TArray>(borderArray), 0);

switch (column.kind()) {
Expand All @@ -98,8 +97,7 @@ bool Compare(const arrow::Datum& column, const std::shared_ptr<arrow::Array>& bo
return false;
}

bool SwitchCompare(const arrow::Datum& column, const std::shared_ptr<arrow::Array>& border,
std::vector<NArrow::ECompareResult>& rowsCmp) {
bool SwitchCompare(const arrow::Datum& column, const std::shared_ptr<arrow::Array>& border, std::vector<NArrow::ECompareResult>& rowsCmp) {
Y_ABORT_UNLESS(border->length() == 1);

// first time it's empty
Expand All @@ -111,12 +109,11 @@ bool SwitchCompare(const arrow::Datum& column, const std::shared_ptr<arrow::Arra
using TWrap = std::decay_t<decltype(type)>;
using TArray = typename arrow::TypeTraits<typename TWrap::T>::ArrayType;
return Compare<TArray>(column, border, rowsCmp);
});
});
}

template <typename T>
void CompositeCompare(std::shared_ptr<T> some, std::shared_ptr<arrow::RecordBatch> borderBatch,
std::vector<NArrow::ECompareResult>& rowsCmp) {
void CompositeCompare(std::shared_ptr<T> some, std::shared_ptr<arrow::RecordBatch> borderBatch, std::vector<NArrow::ECompareResult>& rowsCmp) {
auto key = borderBatch->schema()->fields();
Y_ABORT_UNLESS(key.size());

Expand All @@ -130,11 +127,61 @@ void CompositeCompare(std::shared_ptr<T> some, std::shared_ptr<arrow::RecordBatc
Y_ABORT_UNLESS(some->schema()->GetFieldByName(field->name())->type()->id() == typeId);

if (SwitchCompare(column, border, rowsCmp)) {
break; // early exit in case we have all rows compared: no borders, can omit key tail
break; // early exit in case we have all rows compared: no borders, can omit key tail
}
}
}

} // namespace

TColumnFilter::TSlicesIterator::TSlicesIterator(const TColumnFilter& owner, const std::optional<ui32> start, const std::optional<ui32> count)
: Owner(owner)
, StartIndex(start)
, Count(count) {
AFL_VERIFY(!!StartIndex == !!Count);
AFL_VERIFY(Owner.GetFilter().size());
if (StartIndex) {
AFL_VERIFY(*StartIndex + *Count <= owner.GetRecordsCountVerified())("start", *StartIndex)("count", *count)("size", owner.GetRecordsCount());
}
}

TColumnFilter::TApplyContext& TColumnFilter::TApplyContext::Slice(const ui32 start, const ui32 count) {
AFL_VERIFY(!StartPos && !Count);
StartPos = start;
Count = count;
return *this;
}

ui32 TColumnFilter::TSlicesIterator::GetSliceSize() const {
AFL_VERIFY(IsValid());
if (!StartIndex) {
return *CurrentIterator;
} else {
const ui32 startIndex = GetStartIndex();
const ui32 finishIndex = std::min<ui32>(CurrentStartIndex + *CurrentIterator, *StartIndex + *Count);
AFL_VERIFY(startIndex < finishIndex)("start", startIndex)("finish", finishIndex);
return finishIndex - startIndex;
}
}

void TColumnFilter::TSlicesIterator::Start() {
CurrentStartIndex = 0;
CurrentIsFiltered = Owner.GetStartValue();
CurrentIterator = Owner.GetFilter().begin();
if (StartIndex) {
while (IsValid() && CurrentStartIndex + *CurrentIterator < *StartIndex) {
AFL_VERIFY(Next());
}
AFL_VERIFY(IsValid());
}
}

bool TColumnFilter::TSlicesIterator::Next() {
AFL_VERIFY(IsValid());
CurrentIsFiltered = !CurrentIsFiltered;
CurrentStartIndex += *CurrentIterator;
++CurrentIterator;
return IsValid();
}

bool TColumnFilter::TIterator::Next(const ui32 size) {
Expand Down Expand Up @@ -193,7 +240,8 @@ TString TColumnFilter::TIterator::DebugString() const {
return sb;
}

std::shared_ptr<arrow::BooleanArray> TColumnFilter::BuildArrowFilter(const ui32 expectedSize, const std::optional<ui32> startPos, const std::optional<ui32> count) const {
std::shared_ptr<arrow::BooleanArray> TColumnFilter::BuildArrowFilter(
const ui32 expectedSize, const std::optional<ui32> startPos, const std::optional<ui32> count) const {
AFL_VERIFY(!!startPos == !!count);
auto& simpleFilter = BuildSimpleFilter();
arrow::BooleanBuilder builder;
Expand Down Expand Up @@ -230,7 +278,7 @@ bool TColumnFilter::IsTotalDenyFilter() const {
}

void TColumnFilter::Reset(const ui32 count) {
Count = 0;
RecordsCount = 0;
FilterPlain.reset();
Filter.clear();
Filter.reserve(count / 4);
Expand All @@ -240,13 +288,13 @@ void TColumnFilter::Add(const bool value, const ui32 count) {
if (!count) {
return;
}
if (Y_UNLIKELY(LastValue != value || !Count)) {
if (Y_UNLIKELY(LastValue != value || !RecordsCount)) {
Filter.emplace_back(count);
LastValue = value;
} else {
Filter.back() += count;
}
Count += count;
RecordsCount += count;
}

ui32 TColumnFilter::CrossSize(const ui32 s1, const ui32 f1, const ui32 s2, const ui32 f2) {
Expand All @@ -256,7 +304,8 @@ ui32 TColumnFilter::CrossSize(const ui32 s1, const ui32 f1, const ui32 s2, const
return f - s;
}

NKikimr::NArrow::TColumnFilter TColumnFilter::MakePredicateFilter(const arrow::Datum& datum, const arrow::Datum& border, ECompareType compareType) {
NKikimr::NArrow::TColumnFilter TColumnFilter::MakePredicateFilter(
const arrow::Datum& datum, const arrow::Datum& border, ECompareType compareType) {
std::vector<ECompareResult> cmps;

switch (datum.kind()) {
Expand Down Expand Up @@ -311,17 +360,19 @@ NKikimr::NArrow::TColumnFilter TColumnFilter::MakePredicateFilter(const arrow::D
}

template <class TData>
bool ApplyImpl(const TColumnFilter& filter, std::shared_ptr<TData>& batch, const std::optional<ui32> startPos, const std::optional<ui32> count) {
bool ApplyImpl(const TColumnFilter& filter, std::shared_ptr<TData>& batch, const TColumnFilter::TApplyContext& context) {
if (!batch || !batch->num_rows()) {
return false;
}
AFL_VERIFY(!!startPos == !!count);
if (!filter.IsEmpty()) {
if (startPos) {
AFL_VERIFY(filter.Size() >= *startPos + *count)("filter_size", filter.Size())("start", *startPos)("count", *count);
AFL_VERIFY(*count == (size_t)batch->num_rows())("count", *count)("batch_size", batch->num_rows());
if (context.HasSlice()) {
AFL_VERIFY(filter.GetRecordsCountVerified() >= *context.GetStartPos() + *context.GetCount())(
"filter_size", filter.GetRecordsCountVerified())(
"start", context.GetStartPos())("count", context.GetCount());
AFL_VERIFY(*context.GetCount() == (size_t)batch->num_rows())("count", context.GetCount())("batch_size", batch->num_rows());
} else {
AFL_VERIFY(filter.Size() == (size_t)batch->num_rows())("filter_size", filter.Size())("batch_size", batch->num_rows());
AFL_VERIFY(filter.GetRecordsCountVerified() == (size_t)batch->num_rows())("filter_size", filter.GetRecordsCountVerified())(
"batch_size", batch->num_rows());
}
}
if (filter.IsTotalDenyFilter()) {
Expand All @@ -331,20 +382,27 @@ bool ApplyImpl(const TColumnFilter& filter, std::shared_ptr<TData>& batch, const
if (filter.IsTotalAllowFilter()) {
return true;
}
batch = NAdapter::TDataBuilderPolicy<TData>::ApplyArrowFilter(batch, filter.BuildArrowFilter(batch->num_rows(), startPos, count));
if (context.GetTrySlices() && filter.GetFilter().size() * 10 < filter.GetRecordsCountVerified() &&
filter.GetRecordsCountVerified() < filter.GetFilteredCountVerified() * 50) {
batch =
NAdapter::TDataBuilderPolicy<TData>::ApplySlicesFilter(batch, filter.BuildSlicesIterator(context.GetStartPos(), context.GetCount()));
} else {
batch = NAdapter::TDataBuilderPolicy<TData>::ApplyArrowFilter(
batch, filter.BuildArrowFilter(batch->num_rows(), context.GetStartPos(), context.GetCount()));
}
return batch->num_rows();
}

bool TColumnFilter::Apply(std::shared_ptr<TGeneralContainer>& batch, const std::optional<ui32> startPos, const std::optional<ui32> count) const {
return ApplyImpl(*this, batch, startPos, count);
bool TColumnFilter::Apply(std::shared_ptr<TGeneralContainer>& batch, const TApplyContext& context) const {
return ApplyImpl(*this, batch, context);
}

bool TColumnFilter::Apply(std::shared_ptr<arrow::Table>& batch, const std::optional<ui32> startPos, const std::optional<ui32> count) const {
return ApplyImpl(*this, batch, startPos, count);
bool TColumnFilter::Apply(std::shared_ptr<arrow::Table>& batch, const TApplyContext& context) const {
return ApplyImpl(*this, batch, context);
}

bool TColumnFilter::Apply(std::shared_ptr<arrow::RecordBatch>& batch, const std::optional<ui32> startPos, const std::optional<ui32> count) const {
return ApplyImpl(*this, batch, startPos, count);
bool TColumnFilter::Apply(std::shared_ptr<arrow::RecordBatch>& batch, const TApplyContext& context) const {
return ApplyImpl(*this, batch, context);
}

void TColumnFilter::Apply(const ui32 expectedRecordsCount, std::vector<arrow::Datum*>& datums) const {
Expand Down Expand Up @@ -382,9 +440,9 @@ void TColumnFilter::Apply(const ui32 expectedRecordsCount, std::vector<arrow::Da

const std::vector<bool>& TColumnFilter::BuildSimpleFilter() const {
if (!FilterPlain) {
Y_ABORT_UNLESS(Count);
Y_ABORT_UNLESS(RecordsCount);
std::vector<bool> result;
result.resize(Count, true);
result.resize(RecordsCount, true);
bool currentValue = GetStartValue();
ui32 currentPosition = 0;
for (auto&& i : Filter) {
Expand Down Expand Up @@ -433,12 +491,11 @@ class TColumnFilter::TMergerImpl {
private:
const TColumnFilter& Filter1;
const TColumnFilter& Filter2;

public:
TMergerImpl(const TColumnFilter& filter1, const TColumnFilter& filter2)
: Filter1(filter1)
, Filter2(filter2)
{

, Filter2(filter2) {
}

template <class TMergePolicy>
Expand All @@ -450,7 +507,7 @@ class TColumnFilter::TMergerImpl {
} else if (Filter2.empty()) {
return TMergePolicy::MergeWithSimple(Filter1, Filter2.DefaultFilterValue);
} else {
Y_ABORT_UNLESS(Filter1.Count == Filter2.Count);
Y_ABORT_UNLESS(Filter1.RecordsCount == Filter2.RecordsCount);
auto it1 = Filter1.Filter.cbegin();
auto it2 = Filter2.Filter.cbegin();

Expand Down Expand Up @@ -495,11 +552,10 @@ class TColumnFilter::TMergerImpl {
TColumnFilter result = TColumnFilter::BuildAllowFilter();
std::swap(resultFilter, result.Filter);
std::swap(curCurrent, result.LastValue);
std::swap(count, result.Count);
std::swap(count, result.RecordsCount);
return result;
}
}

};

TColumnFilter TColumnFilter::And(const TColumnFilter& extFilter) const {
Expand Down Expand Up @@ -569,7 +625,7 @@ TColumnFilter TColumnFilter::CombineSequentialAnd(const TColumnFilter& extFilter
TColumnFilter result = TColumnFilter::BuildAllowFilter();
std::swap(resultFilter, result.Filter);
std::swap(curCurrent, result.LastValue);
std::swap(count, result.Count);
std::swap(count, result.RecordsCount);
return result;
}
}
Expand All @@ -580,18 +636,19 @@ TColumnFilter::TIterator TColumnFilter::GetIterator(const bool reverse, const ui
} else if (IsTotalDenyFilter()) {
return TIterator(reverse, expectedSize, false);
} else {
AFL_VERIFY(expectedSize == Size())("expected", expectedSize)("size", Size())("reverse", reverse);
AFL_VERIFY(expectedSize == GetRecordsCountVerified())("expected", expectedSize)("count", GetRecordsCountVerified())(
"reverse", reverse);
return TIterator(reverse, Filter, GetStartValue(reverse));
}
}

std::optional<ui32> TColumnFilter::GetFilteredCount() const {
if (!FilteredCount) {
if (IsTotalAllowFilter()) {
if (!Count) {
if (!RecordsCount) {
return {};
} else {
FilteredCount = Count;
FilteredCount = RecordsCount;
}
} else if (IsTotalDenyFilter()) {
FilteredCount = 0;
Expand All @@ -617,4 +674,25 @@ void TColumnFilter::Append(const TColumnFilter& filter) {
}
}

std::optional<ui32> TColumnFilter::GetRecordsCount() const {
if (Filter.size()) {
AFL_VERIFY(RecordsCount);
return RecordsCount;
} else {
return std::nullopt;
}
}

ui32 TColumnFilter::GetRecordsCountVerified() const {
AFL_VERIFY(Filter.size());
AFL_VERIFY(RecordsCount);
return RecordsCount;
}

ui32 TColumnFilter::GetFilteredCountVerified() const {
const std::optional<ui32> result = GetFilteredCount();
AFL_VERIFY(!!result);
return *result;
}

} // namespace NKikimr::NArrow
Loading

0 comments on commit 1ae3179

Please sign in to comment.