// Reviewer notes (commentary only; the patch text below these notes is unchanged
// apart from one corrected doc comment on the new toVeloxPlan(WindowRel) overload).
//
// What the hunks below show:
//  - Rank.cpp: registerRank now passes "integer" instead of "bigint" to
//    registerRankInternal; registerDenseRank still passes "bigint". The other
//    Rank.cpp / RowNumber.cpp changes shown are re-wrapped lambda signatures.
//  - SubstraitToVeloxPlan.cpp/.h: adds createWindowFrame() and
//    toVeloxPlan(const ::substrait::WindowRel&), and dispatches sRel.has_window().
//    createWindowFrame maps ROWS/RANGE to kRows/kRange and maps only current_row,
//    unbounded_following and unbounded_preceding bounds; any other window type or
//    bound kind hits VELOX_FAIL. startValue/endValue are always set to nullptr.
//  - SubstraitToVeloxPlanValidator.cpp/.h: adds validateBoundType() and
//    validate(const ::substrait::WindowRel&), and dispatches sRel.has_window().
//  - algebra.proto: adds message WindowRel, Rel.window = 16,
//    WindowFunction.column_name = 12 / window_type = 13, enum WindowType, and
//    replaces Bound.Unbounded with Unbounded_Preceding (4) / Unbounded_Following (5).
//
// NOTE(review): template argument lists appear to be missing throughout this text
//   (e.g. `std::vector windowNodeFunctions`, `dynamic_cast(expression.get())`,
//   `std::make_unique>(resultType)`) -- presumably lost in extraction; confirm
//   against the real commit before applying or judging the RowNumber.cpp hunk,
//   whose removed and added lines read identically here.
// NOTE(review): validate(WindowRel) calls validateBoundType() for the upper and
//   lower bound but discards both bool results, so an unsupported bound kind is
//   only printed and does not make validation return false.
// NOTE(review): toVeloxPlan(WindowRel) appends to sortingOrders before checking
//   sort.has_expr(); a sort without an expression adds an order but no key, so
//   the two vectors can end up with different sizes.
// NOTE(review): validate(WindowRel) reports "ProjectRel", "SortRel" and
//   "Sort Operator" in its messages -- these look copied from other validators.
// NOTE(review): three VELOX_CHECK(...) invocations have no trailing semicolon;
//   whether that compiles depends on the macro definition -- TODO confirm.
// NOTE(review): proto field number 4 of Bound.kind changes message type
//   (Unbounded -> Unbounded_Preceding); check compatibility with existing
//   plan producers.
diff --git a/velox/functions/prestosql/window/Rank.cpp b/velox/functions/prestosql/window/Rank.cpp index 6ac498a67505..8c0168617b05 100644 --- a/velox/functions/prestosql/window/Rank.cpp +++ b/velox/functions/prestosql/window/Rank.cpp @@ -107,14 +107,14 @@ void registerRankInternal( const std::vector& /*args*/, const TypePtr& resultType, velox::memory::MemoryPool* /*pool*/, - HashStringAllocator* /*stringAllocator*/) - -> std::unique_ptr { + HashStringAllocator * + /*stringAllocator*/) -> std::unique_ptr { return std::make_unique>(resultType); }); } void registerRank(const std::string& name) { - registerRankInternal(name, "bigint"); + registerRankInternal(name, "integer"); } void registerDenseRank(const std::string& name) { registerRankInternal(name, "bigint"); diff --git a/velox/functions/prestosql/window/RowNumber.cpp b/velox/functions/prestosql/window/RowNumber.cpp index 2b1148cc40e7..25d2687b5326 100644 --- a/velox/functions/prestosql/window/RowNumber.cpp +++ b/velox/functions/prestosql/window/RowNumber.cpp @@ -40,7 +40,7 @@ class RowNumberFunction : public exec::WindowFunction { vector_size_t resultOffset, const VectorPtr& result) override { int numRows = peerGroupStarts->size() / sizeof(vector_size_t); - auto* rawValues = result->asFlatVector()->mutableRawValues(); + auto* rawValues = result->asFlatVector()->mutableRawValues(); for (int i = 0; i < numRows; i++) { rawValues[resultOffset + i] = rowNumber_++; } @@ -68,8 +68,8 @@ void registerRowNumber(const std::string& name) { const std::vector& /*args*/, const TypePtr& /*resultType*/, velox::memory::MemoryPool* /*pool*/, - HashStringAllocator* /*stringAllocator*/) - -> std::unique_ptr { + HashStringAllocator * + /*stringAllocator*/) -> std::unique_ptr { return std::make_unique(); }); } diff --git a/velox/substrait/SubstraitToVeloxPlan.cpp b/velox/substrait/SubstraitToVeloxPlan.cpp index b660bb3fff8a..7c4b97bf3fa4 100644 --- a/velox/substrait/SubstraitToVeloxPlan.cpp +++ 
b/velox/substrait/SubstraitToVeloxPlan.cpp @@ -571,6 +571,155 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( childNode); } +const core::WindowNode::Frame createWindowFrame( + const ::substrait::Expression_WindowFunction_Bound& lower_bound, + const ::substrait::Expression_WindowFunction_Bound& upper_bound, + const ::substrait::WindowType& type) { + core::WindowNode::Frame frame; + switch (type) { + case ::substrait::WindowType::ROWS: + frame.type = core::WindowNode::WindowType::kRows; + break; + case ::substrait::WindowType::RANGE: + + frame.type = core::WindowNode::WindowType::kRange; + break; + default: + VELOX_FAIL( + "the window type only support ROWS and RANGE, and the input type is ", + type); + } + + auto boundTypeConversion = + [](::substrait::Expression_WindowFunction_Bound boundType) + -> core::WindowNode::BoundType { + if (boundType.has_current_row()) { + return core::WindowNode::BoundType::kCurrentRow; + } else if (boundType.has_unbounded_following()) { + return core::WindowNode::BoundType::kUnboundedFollowing; + } else if (boundType.has_unbounded_preceding()) { + return core::WindowNode::BoundType::kUnboundedPreceding; + } else { + VELOX_FAIL("The BoundType is not supported."); + } + }; + frame.startType = boundTypeConversion(lower_bound); + frame.startValue = nullptr; + frame.endType = boundTypeConversion(upper_bound); + frame.endValue = nullptr; + return frame; +} + +core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( + const ::substrait::WindowRel& windowRel) { + core::PlanNodePtr childNode; + if (windowRel.has_input()) { + childNode = toVeloxPlan(windowRel.input()); + } else { + VELOX_FAIL("Child Rel is expected in WindowRel."); + } + + const auto& inputType = childNode->outputType(); + + // Parse measures and get the window expressions. + // Each measure represents one window expression. 
+ bool ignoreNullKeys = false; + std::vector windowNodeFunctions; + std::vector windowColumnNames; + + windowNodeFunctions.reserve(windowRel.measures().size()); + for (const auto& smea : windowRel.measures()) { + const auto& windowFunction = smea.measure(); + std::string funcName = subParser_->findVeloxFunction( + functionMap_, windowFunction.function_reference()); + std::vector> windowParams; + windowParams.reserve(windowFunction.arguments().size()); + for (const auto& arg : windowFunction.arguments()) { + windowParams.emplace_back( + exprConverter_->toVeloxExpr(arg.value(), inputType)); + } + auto windowVeloxType = + toVeloxType(subParser_->parseType(windowFunction.output_type())->type); + auto windowCall = std::make_shared( + windowVeloxType, std::move(windowParams), funcName); + auto upperBound = windowFunction.upper_bound(); + auto lowerBound = windowFunction.lower_bound(); + auto type = windowFunction.window_type(); + + windowColumnNames.push_back(windowFunction.column_name()); + + windowNodeFunctions.push_back( + {std::move(windowCall), + createWindowFrame(lowerBound, upperBound, type), + ignoreNullKeys}); + } + + // Construct partitionKeys + std::vector partitionKeys; + const auto& partitions = windowRel.partition_expressions(); + partitionKeys.reserve(partitions.size()); + for (const auto& partition : partitions) { + auto expression = exprConverter_->toVeloxExpr(partition, inputType); + auto expr_field = + dynamic_cast(expression.get()); + VELOX_CHECK( + expr_field != nullptr, + " the partition key in Window Operator only support field") + + partitionKeys.emplace_back( + std::dynamic_pointer_cast( + expression)); + } + + std::vector sortingKeys; + std::vector sortingOrders; + + const auto& sorts = windowRel.sorts(); + sortingKeys.reserve(sorts.size()); + sortingOrders.reserve(sorts.size()); + + for (const auto& sort : sorts) { + switch (sort.direction()) { + case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST: + 
sortingOrders.emplace_back(core::kAscNullsFirst); + break; + case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST: + sortingOrders.emplace_back(core::kAscNullsLast); + break; + case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST: + sortingOrders.emplace_back(core::kDescNullsFirst); + break; + case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST: + sortingOrders.emplace_back(core::kDescNullsLast); + break; + default: + VELOX_FAIL("Sort direction is not support in WindowRel"); + } + + if (sort.has_expr()) { + auto expression = exprConverter_->toVeloxExpr(sort.expr(), inputType); + auto expr_field = + dynamic_cast(expression.get()); + VELOX_CHECK( + expr_field != nullptr, + " the sorting key in Window Operator only support field") + + sortingKeys.emplace_back( + std::dynamic_pointer_cast( + expression)); + } + } + + return std::make_shared( + nextPlanNodeId(), + partitionKeys, + sortingKeys, + sortingOrders, + windowColumnNames, + windowNodeFunctions, + childNode); +} + core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( const ::substrait::SortRel& sortRel) { auto childNode = convertSingleInput<::substrait::SortRel>(sortRel); @@ -970,6 +1119,9 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( if (sRel.has_fetch()) { return toVeloxPlan(sRel.fetch()); } + if (sRel.has_window()) { + return toVeloxPlan(sRel.window()); + } VELOX_NYI("Substrait conversion not supported for Rel."); } diff --git a/velox/substrait/SubstraitToVeloxPlan.h b/velox/substrait/SubstraitToVeloxPlan.h index 288e27ffcd76..13fe579081e9 100644 --- a/velox/substrait/SubstraitToVeloxPlan.h +++ b/velox/substrait/SubstraitToVeloxPlan.h @@ -52,6 +52,9 @@ class SubstraitVeloxPlanConverter { /// Used to convert Substrait ExpandRel into Velox PlanNode. core::PlanNodePtr toVeloxPlan(const ::substrait::ExpandRel& sExpand); + /// Used to convert Substrait WindowRel into Velox PlanNode. 
+ core::PlanNodePtr toVeloxPlan(const ::substrait::WindowRel& sWindow); + /// Used to convert Substrait JoinRel into Velox PlanNode. core::PlanNodePtr toVeloxPlan(const ::substrait::JoinRel& sJoin); diff --git a/velox/substrait/SubstraitToVeloxPlanValidator.cpp b/velox/substrait/SubstraitToVeloxPlanValidator.cpp index 365185a42c66..89bf331e9fcb 100644 --- a/velox/substrait/SubstraitToVeloxPlanValidator.cpp +++ b/velox/substrait/SubstraitToVeloxPlanValidator.cpp @@ -126,6 +126,139 @@ bool SubstraitToVeloxPlanValidator::validate( return true; } +bool validateBoundType(::substrait::Expression_WindowFunction_Bound boundType) { + switch (boundType.kind_case()) { + case ::substrait::Expression_WindowFunction_Bound::kUnboundedFollowing: + case ::substrait::Expression_WindowFunction_Bound::kUnboundedPreceding: + case ::substrait::Expression_WindowFunction_Bound::kCurrentRow: + break; + default: + std::cout << "The Bound Type is not supported. " + << "\n"; + return false; + } + return true; +} + +bool SubstraitToVeloxPlanValidator::validate( + const ::substrait::WindowRel& sWindow) { + if (sWindow.has_input() && !validate(sWindow.input())) { + return false; + } + + // Get and validate the input types from extension. + if (!sWindow.has_advanced_extension()) { + std::cout << "Input types are expected in WindowRel." << std::endl; + return false; + } + const auto& extension = sWindow.advanced_extension(); + std::vector types; + if (!validateInputTypes(extension, types)) { + std::cout << "Validation failed for input types in WindowRel." 
<< std::endl; + return false; + } + + int32_t inputPlanNodeId = 0; + std::vector names; + names.reserve(types.size()); + for (auto colIdx = 0; colIdx < types.size(); colIdx++) { + names.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx)); + } + auto rowType = std::make_shared(std::move(names), std::move(types)); + + // Validate WindowFunction + std::vector funcSpecs; + funcSpecs.reserve(sWindow.measures().size()); + for (const auto& smea : sWindow.measures()) { + try { + const auto& windowFunction = smea.measure(); + funcSpecs.emplace_back( + planConverter_->findFuncSpec(windowFunction.function_reference())); + toVeloxType(subParser_->parseType(windowFunction.output_type())->type); + for (const auto& arg : windowFunction.arguments()) { + auto typeCase = arg.value().rex_type_case(); + switch (typeCase) { + case ::substrait::Expression::RexTypeCase::kSelection: + case ::substrait::Expression::RexTypeCase::kLiteral: + break; + default: + std::cout << "Only field is supported in window functions." + << std::endl; + return false; + } + } + // Validate BoundType and Frame Type + switch (windowFunction.window_type()) { + case ::substrait::WindowType::ROWS: + case ::substrait::WindowType::RANGE: + break; + default: + VELOX_FAIL( + "the window type only support ROWS and RANGE, and the input type is ", + windowFunction.window_type()); + } + + validateBoundType(windowFunction.upper_bound()); + validateBoundType(windowFunction.lower_bound()); + + } catch (const VeloxException& err) { + std::cout << "Validation failed for window function due to: " + << err.message() << std::endl; + return false; + } + } + + // Validate groupby expression + const auto& groupByExprs = sWindow.partition_expressions(); + std::vector> expressions; + expressions.reserve(groupByExprs.size()); + try { + for (const auto& expr : groupByExprs) { + expressions.emplace_back(exprConverter_->toVeloxExpr(expr, rowType)); + } + // Try to compile the expressions. 
If there is any unregistred funciton or + // mismatched type, exception will be thrown. + exec::ExprSet exprSet(std::move(expressions), execCtx_); + } catch (const VeloxException& err) { + std::cout << "Validation failed for expression in ProjectRel due to:" + << err.message() << std::endl; + return false; + } + + // Validate Sort expression + const auto& sorts = sWindow.sorts(); + for (const auto& sort : sorts) { + switch (sort.direction()) { + case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST: + case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST: + case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST: + case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST: + break; + default: + return false; + } + + if (sort.has_expr()) { + try { + auto expression = exprConverter_->toVeloxExpr(sort.expr(), rowType); + auto expr_field = + dynamic_cast(expression.get()); + VELOX_CHECK( + expr_field != nullptr, + " the sorting key in Sort Operator only support field") + + exec::ExprSet exprSet({std::move(expression)}, execCtx_); + } catch (const VeloxException& err) { + std::cout << "Validation failed for expression in SortRel due to:" + << err.message() << std::endl; + return false; + } + } + } + + return true; +} + bool SubstraitToVeloxPlanValidator::validate( const ::substrait::SortRel& sSort) { if (sSort.has_input() && !validate(sSort.input())) { @@ -582,6 +715,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Rel& sRel) { if (sRel.has_fetch()) { return validate(sRel.fetch()); } + if (sRel.has_window()) { + return validate(sRel.window()); + } return false; } diff --git a/velox/substrait/SubstraitToVeloxPlanValidator.h b/velox/substrait/SubstraitToVeloxPlanValidator.h index da58a745bf48..9ea547204757 100644 --- a/velox/substrait/SubstraitToVeloxPlanValidator.h +++ b/velox/substrait/SubstraitToVeloxPlanValidator.h @@ -38,6 +38,9 @@ class SubstraitToVeloxPlanValidator { /// Used 
to validate whether the computing of this Sort is supported. bool validate(const ::substrait::SortRel& sSort); + /// Used to validate whether the computing of this Window is supported. + bool validate(const ::substrait::WindowRel& sWindow); + /// Used to validate whether the computing of this Aggregation is supported. bool validate(const ::substrait::AggregateRel& sAgg); diff --git a/velox/substrait/proto/substrait/algebra.proto b/velox/substrait/proto/substrait/algebra.proto index a50dab28fedf..636073994f86 100644 --- a/velox/substrait/proto/substrait/algebra.proto +++ b/velox/substrait/proto/substrait/algebra.proto @@ -238,6 +238,19 @@ message SortRel { substrait.extensions.AdvancedExtension advanced_extension = 10; } +message WindowRel { + RelCommon common = 1; + Rel input = 2; + repeated Measure measures = 3; + repeated Expression partition_expressions = 4; + repeated SortField sorts = 5; + substrait.extensions.AdvancedExtension advanced_extension = 10; + + message Measure { + Expression.WindowFunction measure = 1; + } +} + // The relational operator capturing simple FILTERs (as in the WHERE clause of SQL) message FilterRel { RelCommon common = 1; @@ -389,6 +402,7 @@ message Rel { HashJoinRel hash_join = 13; MergeJoinRel merge_join = 14; ExpandRel expand = 15; + WindowRel window = 16; } } @@ -856,6 +870,9 @@ message Expression { // Optional; defaults to the start of the partition. Bound lower_bound = 5; + string column_name = 12; + WindowType window_type = 13; + // Defines the record relative to the current record up to which the window // extends. The bound is inclusive. If the upper bound indexes a record // less than the lower bound, TODO (null range/no records passed? @@ -887,10 +904,9 @@ message Expression { // Defines that the bound extends to or from the current record. message CurrentRow {} - // Defines an "unbounded bound": for lower bounds this means the start - // of the partition, and for upper bounds this means the end of the - // partition. 
- message Unbounded {} + message Unbounded_Preceding {} + + message Unbounded_Following {} oneof kind { // The bound extends some number of records behind the current record. @@ -903,10 +919,8 @@ message Expression { // The bound extends to the current record. CurrentRow current_row = 3; - // The bound extends to the start of the partition or the end of the - // partition, depending on whether this represents the upper or lower - // bound. - Unbounded unbounded = 4; + Unbounded_Preceding unbounded_preceding = 4; + Unbounded_Following unbounded_following = 5; } } } @@ -1244,6 +1258,11 @@ enum AggregationPhase { AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT = 4; } +enum WindowType { + ROWS = 0; + RANGE = 1; +} + // An aggregate function. message AggregateFunction { // Points to a function_anchor defined in this plan, which must refer