From ce48a27b6deb22e3a7f75d31a8f2f5e86ab3e091 Mon Sep 17 00:00:00 2001 From: PHILO-HE Date: Wed, 31 May 2023 14:46:22 +0800 Subject: [PATCH] Support kPreceeding & kFollowing for window range frame type (#287) * Initial commit * Fix compile issue * Cherry pick PR 4510 * Fix issue in upstream PR * Fix bugs for unique sort key * Fix bugs for repeated sort key * Add more test cases * Fix int type issue * Handle null * Remove some commented code * Remove check null * Fix velox ut failure for rows frame * Format the code --- CMakeLists.txt | 1 + velox/exec/Window.cpp | 333 +++++++++++++++++- velox/exec/Window.h | 56 +++ .../lib/window/tests/WindowTestBase.cpp | 35 ++ .../lib/window/tests/WindowTestBase.h | 2 + .../prestosql/window/tests/NthValueTest.cpp | 7 + .../prestosql/window/tests/RankTest.cpp | 5 + .../window/tests/SimpleAggregatesTest.cpp | 79 +++++ velox/substrait/SubstraitToVeloxPlan.cpp | 25 +- .../SubstraitToVeloxPlanValidator.cpp | 2 + 10 files changed, 539 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bff48ecc5652a..92b1c83a6b80b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,6 +29,7 @@ set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/CMake" ${CMAKE_MODULE_PATH}) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) # set the project name project(velox) +add_definitions("-DNDEBUG") list(PREPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/CMake") diff --git a/velox/exec/Window.cpp b/velox/exec/Window.cpp index 302797b371bbb..4f780cd986b92 100644 --- a/velox/exec/Window.cpp +++ b/velox/exec/Window.cpp @@ -83,6 +83,8 @@ Window::Window( std::make_unique(inputColumns, inputType->children()); createWindowFunctions(windowNode, inputType); + + initRangeValuesMap(); } Window::WindowFrame Window::createWindowFrame( @@ -110,6 +112,17 @@ Window::WindowFrame Window::createWindowFrame( } }; + // If this is a k Range frame bound, then its evaluation requires that the + // order by key be a single column (to add or subtract the k range value + // 
from). + if (frame.type == core::WindowNode::WindowType::kRange && + (frame.startValue || frame.endValue)) { + VELOX_USER_CHECK_EQ( + sortKeyInfo_.size(), + 1, + "Window frame of type RANGE PRECEDING or FOLLOWING requires single sort item in ORDER BY."); + } + return WindowFrame( {frame.type, frame.startType, @@ -148,6 +161,25 @@ void Window::createWindowFunctions( } } +void Window::initRangeValuesMap() { + auto isKBoundFrame = [](core::WindowNode::BoundType boundType) -> bool { + return boundType == core::WindowNode::BoundType::kPreceding || + boundType == core::WindowNode::BoundType::kFollowing; + }; + + hasKRangeFrames_ = false; + for (const auto& frame : windowFrames_) { + if (frame.type == core::WindowNode::WindowType::kRange && + (isKBoundFrame(frame.startType) || isKBoundFrame(frame.endType))) { + hasKRangeFrames_ = true; + rangeValuesMap_.rangeType = outputType_->childAt(sortKeyInfo_[0].first); + rangeValuesMap_.rangeValues = + BaseVector::create(rangeValuesMap_.rangeType, 0, pool()); + break; + } + } +} + void Window::addInput(RowVectorPtr input) { inputRows_.resize(input->size()); @@ -275,6 +307,35 @@ void Window::noMoreInput() { createPeerAndFrameBuffers(); } +void Window::computeRangeValuesMap() { + auto peerCompare = [&](const char* lhs, const char* rhs) -> bool { + return compareRowsWithKeys(lhs, rhs, sortKeyInfo_); + }; + auto firstPartitionRow = partitionStartRows_[currentPartition_]; + auto lastPartitionRow = partitionStartRows_[currentPartition_ + 1] - 1; + auto numRows = lastPartitionRow - firstPartitionRow + 1; + rangeValuesMap_.rangeValues->resize(numRows); + rangeValuesMap_.rowIndices.resize(numRows); + + rangeValuesMap_.rowIndices[0] = 0; + int j = 1; + for (auto i = firstPartitionRow + 1; i <= lastPartitionRow; i++) { + // Here, we removed the below check code, in order to keep raw values. 
+ // if (peerCompare(sortedRows_[i - 1], sortedRows_[i])) { + // The order by values are extracted from the Window partition which + // starts from row number 0 for the firstPartitionRow. So the index + // requires adjustment. + rangeValuesMap_.rowIndices[j++] = i - firstPartitionRow; + // } + } + + // If sort key is desc then reverse the rowIndices so that the range values + // are guaranteed ascending for the further lookup logic. + auto valueIndexesRange = folly::Range(rangeValuesMap_.rowIndices.data(), j); + windowPartition_->extractColumn( + sortKeyInfo_[0].first, valueIndexesRange, 0, rangeValuesMap_.rangeValues); +} + void Window::callResetPartition(vector_size_t partitionNumber) { partitionOffset_ = 0; auto partitionSize = partitionStartRows_[partitionNumber + 1] - @@ -285,6 +346,10 @@ void Window::callResetPartition(vector_size_t partitionNumber) { for (int i = 0; i < windowFunctions_.size(); i++) { windowFunctions_[i]->resetPartition(windowPartition_.get()); } + + if (hasKRangeFrames_) { + computeRangeValuesMap(); + } } void Window::updateKRowsFrameBounds( @@ -299,7 +364,17 @@ void Window::updateKRowsFrameBounds( auto constantOffset = frameArg.constant.value(); auto startValue = startRow + (isKPreceding ? -constantOffset : constantOffset) - firstPartitionRow; - std::iota(rawFrameBounds, rawFrameBounds + numRows, startValue); + auto lastPartitionRow = partitionStartRows_[currentPartition_ + 1] - 1; + // TODO: check first partition boundary and validate the frame. + for (int i = 0; i < numRows; i++) { + if (startValue > lastPartitionRow) { + rawFrameBounds[i] = lastPartitionRow + 1; + } else { + rawFrameBounds[i] = startValue; + } + startValue++; + } + // std::iota(rawFrameBounds, rawFrameBounds + numRows, startValue); } else { windowPartition_->extractColumn( frameArg.index, partitionOffset_, numRows, 0, frameArg.value); @@ -315,12 +390,174 @@ void Window::updateKRowsFrameBounds( // moves ahead. int precedingFactor = isKPreceding ? 
-1 : 1; for (auto i = 0; i < numRows; i++) { + // TODO: check whether the value is inside [firstPartitionRow, + // lastPartitionRow]. rawFrameBounds[i] = (startRow + i) + vector_size_t(precedingFactor * offsets[i]) - firstPartitionRow; } } } +namespace { + +template +vector_size_t findIndex( + const T value, + vector_size_t leftBound, + vector_size_t rightBound, + const FlatVectorPtr& values, + bool findStart) { + vector_size_t originalRightBound = rightBound; + vector_size_t originalLeftBound = leftBound; + while (leftBound < rightBound) { + vector_size_t mid = round((leftBound + rightBound) / 2.0); + auto midValue = values->valueAt(mid); + if (value == midValue) { + return mid; + } + + if (value < midValue) { + rightBound = mid - 1; + } else { + leftBound = mid + 1; + } + } + + // The value is not found but leftBound == rightBound at this point. + // This could be a value which is the least number greater than + // or the largest number less than value. + // The semantics of this function are to always return the smallest larger + // value (or rightBound if end of range). + if (findStart) { + if (value <= values->valueAt(rightBound)) { + // return std::max(originalLeftBound, rightBound); + return rightBound; + } + return std::min(originalRightBound, rightBound + 1); + } + if (value < values->valueAt(rightBound)) { + return std::max(originalLeftBound, rightBound - 1); + } + // std::max(originalLeftBound, rightBound)? + return std::min(originalRightBound, rightBound); +} + +} // namespace + +// TODO: unify into one function. +template +inline vector_size_t Window::kRangeStartBoundSearch( + const T value, + vector_size_t leftBound, + vector_size_t rightBound, + const FlatVectorPtr& valuesVector, + const vector_size_t* rawPeerStarts) { + auto index = findIndex(value, leftBound, rightBound, valuesVector, true); + // Since this is a kPreceding bound it includes the row at the index.
+ return rangeValuesMap_.rowIndices[rawPeerStarts[index]]; +} + +// TODO: lastRightBoundRow looks useless. +template +vector_size_t Window::kRangeEndBoundSearch( + const T value, + vector_size_t leftBound, + vector_size_t rightBound, + vector_size_t lastRightBoundRow, + const FlatVectorPtr& valuesVector, + const vector_size_t* rawPeerEnds) { + auto index = findIndex(value, leftBound, rightBound, valuesVector, false); + return rangeValuesMap_.rowIndices[rawPeerEnds[index]]; +} + +template +void Window::updateKRangeFrameBounds( + bool isKPreceding, + bool isStartBound, + const FrameChannelArg& frameArg, + vector_size_t numRows, + vector_size_t* rawFrameBounds, + const vector_size_t* rawPeerStarts, + const vector_size_t* rawPeerEnds) { + using NativeType = typename TypeTraits::NativeType; + // Extract the order by key column to calculate the range values for the frame + // boundaries. + std::shared_ptr sortKeyType = + outputType_->childAt(sortKeyInfo_[0].first); + auto orderByValues = BaseVector::create(sortKeyType, numRows, pool()); + windowPartition_->extractColumn( + sortKeyInfo_[0].first, partitionOffset_, numRows, 0, orderByValues); + auto* rangeValuesFlatVector = orderByValues->asFlatVector(); + auto* rawRangeValues = rangeValuesFlatVector->mutableRawValues(); + + if (frameArg.index == kConstantChannel) { + auto constantOffset = frameArg.constant.value(); + constantOffset = isKPreceding ? -constantOffset : constantOffset; + for (int i = 0; i < numRows; i++) { + rawRangeValues[i] = rangeValuesFlatVector->valueAt(i) + constantOffset; + } + } else { + windowPartition_->extractColumn( + frameArg.index, partitionOffset_, numRows, 0, frameArg.value); + auto offsets = frameArg.value->values()->as(); + for (auto i = 0; i < numRows; i++) { + VELOX_USER_CHECK( + !frameArg.value->isNullAt(i), "k in frame bounds cannot be null"); + VELOX_USER_CHECK_GE( + offsets[i], 1, "k in frame bounds must be at least 1"); + } + + auto precedingFactor = isKPreceding ? 
-1 : 1; + for (auto i = 0; i < numRows; i++) { + rawRangeValues[i] = rangeValuesFlatVector->valueAt(i) + + vector_size_t(precedingFactor * offsets[i]); + } + } + + // Set the frame bounds from looking up the rangeValues index. + auto leftBound = 0; + auto rightBound = rangeValuesMap_.rowIndices.size() - 1; + auto lastPartitionRow = partitionStartRows_[currentPartition_ + 1] - 1; + auto rangeIndexValues = std::dynamic_pointer_cast>( + rangeValuesMap_.rangeValues); + if (isStartBound) { + for (auto i = 0; i < numRows; i++) { + // Handle null. + // Different with duckDB result. May need to separate the handling for + // spark & presto. + if (rangeValuesFlatVector->mayHaveNulls() && + rangeValuesFlatVector->isNullAt(i)) { + rawFrameBounds[i] = i; + continue; + } + rawFrameBounds[i] = kRangeStartBoundSearch( + rawRangeValues[i], + leftBound, + rightBound, + rangeIndexValues, + rawPeerStarts); + } + } else { + for (auto i = 0; i < numRows; i++) { + // Handle null. + // Different with duckDB result. May need to separate the handling for + // spark & presto. + if (rangeValuesFlatVector->mayHaveNulls() && + rangeValuesFlatVector->isNullAt(i)) { + rawFrameBounds[i] = i; + continue; + } + rawFrameBounds[i] = kRangeEndBoundSearch( + rawRangeValues[i], + leftBound, + rightBound, + lastPartitionRow, + rangeIndexValues, + rawPeerEnds); + } + } +} + void Window::updateFrameBounds( const WindowFrame& windowFrame, const bool isStartBound, @@ -365,7 +602,52 @@ void Window::updateFrameBounds( updateKRowsFrameBounds( true, frameArg.value(), startRow, numRows, rawFrameBounds); } else { - VELOX_NYI("k preceding frame is only supported in ROWS mode"); + // Sort key type. 
+ auto sortKeyTypePtr = outputType_->childAt(sortKeyInfo_[0].first); + switch (sortKeyTypePtr->kind()) { + case TypeKind::TINYINT: + updateKRangeFrameBounds( + true, + isStartBound, + frameArg.value(), + numRows, + rawFrameBounds, + rawPeerStarts, + rawPeerEnds); + break; + case TypeKind::SMALLINT: + updateKRangeFrameBounds( + true, + isStartBound, + frameArg.value(), + numRows, + rawFrameBounds, + rawPeerStarts, + rawPeerEnds); + break; + case TypeKind::INTEGER: + updateKRangeFrameBounds( + true, + isStartBound, + frameArg.value(), + numRows, + rawFrameBounds, + rawPeerStarts, + rawPeerEnds); + break; + case TypeKind::BIGINT: + updateKRangeFrameBounds( + true, + isStartBound, + frameArg.value(), + numRows, + rawFrameBounds, + rawPeerStarts, + rawPeerEnds); + break; + default: + VELOX_USER_FAIL("Not supported type for sort key!"); + } } break; } @@ -374,7 +656,52 @@ void Window::updateFrameBounds( updateKRowsFrameBounds( false, frameArg.value(), startRow, numRows, rawFrameBounds); } else { - VELOX_NYI("k following frame is only supported in ROWS mode"); + // Sort key type. 
+ auto sortKeyTypePtr = outputType_->childAt(sortKeyInfo_[0].first); + switch (sortKeyTypePtr->kind()) { + case TypeKind::TINYINT: + updateKRangeFrameBounds( + false, + isStartBound, + frameArg.value(), + numRows, + rawFrameBounds, + rawPeerStarts, + rawPeerEnds); + break; + case TypeKind::SMALLINT: + updateKRangeFrameBounds( + false, + isStartBound, + frameArg.value(), + numRows, + rawFrameBounds, + rawPeerStarts, + rawPeerEnds); + break; + case TypeKind::INTEGER: + updateKRangeFrameBounds( + false, + isStartBound, + frameArg.value(), + numRows, + rawFrameBounds, + rawPeerStarts, + rawPeerEnds); + break; + case TypeKind::BIGINT: + updateKRangeFrameBounds( + false, + isStartBound, + frameArg.value(), + numRows, + rawFrameBounds, + rawPeerStarts, + rawPeerEnds); + break; + default: + VELOX_USER_FAIL("Not supported type for sort key!"); + } } break; } diff --git a/velox/exec/Window.h b/velox/exec/Window.h index 916b01698750a..7de1a4e556a1b 100644 --- a/velox/exec/Window.h +++ b/velox/exec/Window.h @@ -86,6 +86,9 @@ class Window : public Operator { const std::shared_ptr& windowNode, const RowTypePtr& inputType); + // Helper function to initialize range values map for k Range frames. + void initRangeValuesMap(); + // Helper function to create the buffers for peer and frame // row indices to send in window function apply invocations. void createPeerAndFrameBuffers(); @@ -110,6 +113,11 @@ class Window : public Operator { // all WindowFunctions. void callResetPartition(vector_size_t partitionNumber); + // For k Range frames an auxiliary structure used to look up the index + // of frame values is required. This function computes that structure for + // each partition of rows. + void computeRangeValuesMap(); + // Helper method to call WindowFunction::apply to all the rows // of a partition between startRow and endRow. 
The outputs // will be written to the vectors in windowFunctionOutputs @@ -148,6 +156,16 @@ class Window : public Operator { vector_size_t numRows, vector_size_t* rawFrameBounds); + template + void updateKRangeFrameBounds( + bool isKPreceding, + bool isStartBound, + const FrameChannelArg& frameArg, + vector_size_t numRows, + vector_size_t* rawFrameBounds, + const vector_size_t* rawPeerStarts, + const vector_size_t* rawPeerEnds); + // Helper function to update frame bounds. void updateFrameBounds( const WindowFrame& windowFrame, @@ -158,6 +176,23 @@ class Window : public Operator { const vector_size_t* rawPeerEnds, vector_size_t* rawFrameBounds); + template + vector_size_t kRangeStartBoundSearch( + const T value, + vector_size_t leftBound, + vector_size_t rightBound, + const FlatVectorPtr& valuesVector, + const vector_size_t* rawPeerStarts); + + template + vector_size_t kRangeEndBoundSearch( + const T value, + vector_size_t leftBound, + vector_size_t rightBound, + vector_size_t lastRightBoundRow, + const FlatVectorPtr& valuesVector, + const vector_size_t* rawPeerEnds); + bool finished_ = false; const vector_size_t numInputColumns_; @@ -243,6 +278,27 @@ class Window : public Operator { // There is one SelectivityVector per window function. std::vector validFrames_; + // When computing k Range frames, the range value for the frame index needs + // to be mapped to the partition row for the value. + // This is an auxiliary structure to materialize a mapping from + // range value -> row index (in RowContainer) for that purpose. + // It uses a vector of the ordered range values and another vector of the + // corresponding row indices. Ideally a binary search + // tree or B-tree index (especially if the data is spilled to disk) should be + // used. + struct RangeValuesMap { + TypePtr rangeType; + // The range values appear in sorted order in this vector. + VectorPtr rangeValues; + // TODO (Make this a BufferPtr so that it can be allocated in the + // MemoryPool) ? 
+ std::vector rowIndices; + }; + RangeValuesMap rangeValuesMap_; + + // The above mapping is built only if required for k range frames. + bool hasKRangeFrames_; + // Number of rows output from the WindowOperator so far. The rows // are output in the same order of the pointers in sortedRows. This // value is updated as the WindowFunction::apply() function is diff --git a/velox/functions/lib/window/tests/WindowTestBase.cpp b/velox/functions/lib/window/tests/WindowTestBase.cpp index 5083b72c56f01..763f2e955c652 100644 --- a/velox/functions/lib/window/tests/WindowTestBase.cpp +++ b/velox/functions/lib/window/tests/WindowTestBase.cpp @@ -129,6 +129,41 @@ void WindowTestBase::testWindowFunction( } } +void WindowTestBase::testKRangeFrames(const std::string& function) { + // The current support for k Range frames is limited to ascending sort + // orders without null values. Frames clauses generating empty frames + // are also not supported. + + // For deterministic results its expected that rows have a fixed ordering + // in the partition so that the range frames are predictable. So the + // input table. 
+ vector_size_t size = 100; + + auto vectors = makeRowVector({ + makeFlatVector(size, [](auto row) { return row % 10; }), + makeFlatVector(size, [](auto row) { return row; }), + makeFlatVector(size, [](auto row) { return row % 7 + 1; }), + makeFlatVector(size, [](auto row) { return row % 4 + 1; }), + }); + + const std::string overClause = "partition by c0 order by c1"; + const std::vector kRangeFrames = { + "range between 5 preceding and current row", + "range between current row and 5 following", + "range between 5 preceding and 5 following", + "range between unbounded preceding and 5 following", + "range between 5 preceding and unbounded following", + + "range between c3 preceding and current row", + "range between current row and c3 following", + "range between c2 preceding and c3 following", + "range between unbounded preceding and c3 following", + "range between c3 preceding and unbounded following", + }; + + testWindowFunction({vectors}, function, {overClause}, kRangeFrames); +} + void WindowTestBase::assertWindowFunctionError( const std::vector& input, const std::string& function, diff --git a/velox/functions/lib/window/tests/WindowTestBase.h b/velox/functions/lib/window/tests/WindowTestBase.h index 6bf4f1f58ef99..c96cc824fef54 100644 --- a/velox/functions/lib/window/tests/WindowTestBase.h +++ b/velox/functions/lib/window/tests/WindowTestBase.h @@ -155,6 +155,8 @@ class WindowTestBase : public exec::test::OperatorTestBase { const std::vector& overClauses, const std::vector& frameClauses = {""}); + void testKRangeFrames(const std::string& function); + /// This function tests the SQL query for the window function and overClause /// combination with the input RowVectors. It is expected that query execution /// will throw an exception with the errorMessage specified. 
diff --git a/velox/functions/prestosql/window/tests/NthValueTest.cpp b/velox/functions/prestosql/window/tests/NthValueTest.cpp index 443f6742ee76c..8616f878402af 100644 --- a/velox/functions/prestosql/window/tests/NthValueTest.cpp +++ b/velox/functions/prestosql/window/tests/NthValueTest.cpp @@ -205,6 +205,13 @@ TEST_F(NthValueTest, nullOffsets) { {vectors}, "nth_value(c0, c2)", kOverClauses); } +TEST_F(NthValueTest, kRangeFrames) { + testKRangeFrames("nth_value(c2, 1)"); + testKRangeFrames("nth_value(c2, 3)"); + testKRangeFrames("nth_value(c2, 5)"); + // testKRangeFrames("nth_value(c2, c3)"); +} + TEST_F(NthValueTest, invalidOffsets) { vector_size_t size = 20; diff --git a/velox/functions/prestosql/window/tests/RankTest.cpp b/velox/functions/prestosql/window/tests/RankTest.cpp index 6ce25303552ec..b24f0924024a0 100644 --- a/velox/functions/prestosql/window/tests/RankTest.cpp +++ b/velox/functions/prestosql/window/tests/RankTest.cpp @@ -103,6 +103,11 @@ TEST_P(RankTest, randomInput) { testWindowFunction({makeRandomInputVector(20), makeRandomInputVector(30)}); } +// Tests function with k RANGE frames (k PRECEDING / k FOLLOWING bounds). +TEST_P(RankTest, rangeFrames) { + testKRangeFrames(function_); +} + // Run above tests for all combinations of rank function and over clauses. VELOX_INSTANTIATE_TEST_SUITE_P( RankTestInstantiation, diff --git a/velox/functions/prestosql/window/tests/SimpleAggregatesTest.cpp b/velox/functions/prestosql/window/tests/SimpleAggregatesTest.cpp index 2ed8f7a4a7e59..edd7fcea341c5 100644 --- a/velox/functions/prestosql/window/tests/SimpleAggregatesTest.cpp +++ b/velox/functions/prestosql/window/tests/SimpleAggregatesTest.cpp @@ -105,6 +105,11 @@ TEST_P(SimpleAggregatesTest, randomInput) { testWindowFunction({makeRandomInputVector(50)}); } +// Tests function with k RANGE frames (k PRECEDING / k FOLLOWING bounds).
+TEST_P(SimpleAggregatesTest, rangeFrames) { + testKRangeFrames(function_); +} + // Instantiate all the above tests for each combination of aggregate function // and over clause. VELOX_INSTANTIATE_TEST_SUITE_P( @@ -128,5 +133,79 @@ TEST_F(StringAggregatesTest, nonFixedWidthAggregate) { testWindowFunction(input, "max(c2)", kOverClauses); } +class KPreceedingFollowingTest : public WindowTestBase {}; + +TEST_F(KPreceedingFollowingTest, rangeFrames1) { + auto vectors = makeRowVector({ + makeFlatVector({1, 1, 2147483650, 3, 2, 2147483650}), + makeFlatVector({"1", "1", "1", "2", "1", "2"}), + }); + + const std::string overClause = "partition by c1 order by c0"; + const std::vector kRangeFrames1 = { + "range between current row and 2147483648 following", + }; + testWindowFunction({vectors}, "count(c0)", {overClause}, kRangeFrames1); + + const std::vector kRangeFrames2 = { + "range between 2147483648 preceding and current row", + }; + testWindowFunction({vectors}, "count(c0)", {overClause}, kRangeFrames2); +} + +TEST_F(KPreceedingFollowingTest, rangeFrames2) { + const std::vector vectors = { + makeRowVector( + {makeFlatVector({5, 6, 8, 9, 10, 2, 8, 9, 3}), + makeFlatVector( + {"1", "1", "1", "1", "1", "2", "2", "2", "2"})}), + // Has repeated sort key. + makeRowVector( + {makeFlatVector({5, 5, 3, 2, 8}), + makeFlatVector({"1", "1", "1", "2", "1"})}), + makeRowVector( + {makeFlatVector({5, 5, 4, 6, 3, 2, 8, 9, 9}), + makeFlatVector( + {"1", "1", "2", "2", "1", "2", "1", "1", "2"})}), + makeRowVector( + {makeFlatVector({5, 5, 4, 6, 3, 2}), + makeFlatVector({"1", "2", "2", "2", "1", "2"})}), + // Uses int32 for sort column. 
+ makeRowVector( + {makeFlatVector({5, 5, 4, 6, 3, 2}), + makeFlatVector({"1", "2", "2", "2", "1", "2"})}), + }; + + const std::string overClause = "partition by c1 order by c0"; + const std::vector kRangeFrames = { + "range between unbounded preceding and 1 following", + "range between unbounded preceding and 2 following", + "range between unbounded preceding and 3 following", + "range between 1 preceding and unbounded following", + "range between 2 preceding and unbounded following", + "range between 3 preceding and unbounded following", + "range between 1 preceding and 3 following", + "range between 3 preceding and 1 following", + "range between 2 preceding and 2 following"}; + for (int i = 0; i < vectors.size(); i++) { + testWindowFunction({vectors[i]}, "avg(c0)", {overClause}, kRangeFrames); + testWindowFunction({vectors[i]}, "sum(c0)", {overClause}, kRangeFrames); + testWindowFunction({vectors[i]}, "count(c0)", {overClause}, kRangeFrames); + } +} + +TEST_F(KPreceedingFollowingTest, rowsFrames) { + auto vectors = makeRowVector({ + makeFlatVector({1, 1, 2147483650, 3, 2, 2147483650}), + makeFlatVector({"1", "1", "1", "2", "1", "2"}), + }); + + const std::string overClause = "partition by c1 order by c0"; + const std::vector kRangeFrames = { + "rows between current row and 2147483647 following", + }; + testWindowFunction({vectors}, "count(c0)", {overClause}, kRangeFrames); +} + }; // namespace }; // namespace facebook::velox::window::test diff --git a/velox/substrait/SubstraitToVeloxPlan.cpp b/velox/substrait/SubstraitToVeloxPlan.cpp index 4ebca3da1844b..cc9ca71c824eb 100644 --- a/velox/substrait/SubstraitToVeloxPlan.cpp +++ b/velox/substrait/SubstraitToVeloxPlan.cpp @@ -539,7 +539,6 @@ const core::WindowNode::Frame createWindowFrame( frame.type = core::WindowNode::WindowType::kRows; break; case ::substrait::WindowType::RANGE: - frame.type = core::WindowNode::WindowType::kRange; break; default: @@ -557,14 +556,34 @@ const core::WindowNode::Frame 
createWindowFrame( return core::WindowNode::BoundType::kUnboundedFollowing; } else if (boundType.has_unbounded_preceding()) { return core::WindowNode::BoundType::kUnboundedPreceding; + } else if (boundType.has_following()) { + return core::WindowNode::BoundType::kFollowing; + } else if (boundType.has_preceding()) { + return core::WindowNode::BoundType::kPreceding; } else { VELOX_FAIL("The BoundType is not supported."); } }; frame.startType = boundTypeConversion(lower_bound); - frame.startValue = nullptr; + switch (frame.startType) { + case core::WindowNode::BoundType::kPreceding: + // TODO: support non-literal expression. + frame.startValue = std::make_shared( + BIGINT(), variant(lower_bound.preceding().offset())); + break; + default: + frame.startValue = nullptr; + } frame.endType = boundTypeConversion(upper_bound); - frame.endValue = nullptr; + switch (frame.endType) { + // TODO: support non-literal expression. + case core::WindowNode::BoundType::kFollowing: + frame.endValue = std::make_shared( + BIGINT(), variant(upper_bound.following().offset())); + break; + default: + frame.endValue = nullptr; + } return frame; } diff --git a/velox/substrait/SubstraitToVeloxPlanValidator.cpp b/velox/substrait/SubstraitToVeloxPlanValidator.cpp index 6664e40daabee..f3376e43fe99b 100644 --- a/velox/substrait/SubstraitToVeloxPlanValidator.cpp +++ b/velox/substrait/SubstraitToVeloxPlanValidator.cpp @@ -329,6 +329,8 @@ bool validateBoundType(::substrait::Expression_WindowFunction_Bound boundType) { case ::substrait::Expression_WindowFunction_Bound::kUnboundedFollowing: case ::substrait::Expression_WindowFunction_Bound::kUnboundedPreceding: case ::substrait::Expression_WindowFunction_Bound::kCurrentRow: + case ::substrait::Expression_WindowFunction_Bound::kFollowing: + case ::substrait::Expression_WindowFunction_Bound::kPreceding: break; default: std::cout << "The Bound Type is not supported. "