From f5a8b545331396cfeb9a91511536af04ef9bc299 Mon Sep 17 00:00:00 2001 From: Grigoriy Pisarenko Date: Wed, 27 Nov 2024 15:47:05 +0000 Subject: [PATCH 1/4] Supported pushdown for SafeCast, ToBytes, FlatMap --- .../providers/common/pushdown/collection.cpp | 1067 ++++++++--------- .../common/pushdown/physical_opt.cpp | 10 +- .../providers/common/pushdown/physical_opt.h | 4 + .../common/pushdown/predicate_node.cpp | 10 + .../common/pushdown/predicate_node.h | 1 + .../yql/providers/common/pushdown/settings.h | 4 +- .../api/service/protos/connector.proto | 8 + .../yql_generic_predicate_pushdown.cpp | 258 ++-- .../provider/yql_generic_predicate_pushdown.h | 1 + .../pushdown/yql_generic_match_predicate.cpp | 3 + .../pq/provider/yql_pq_logical_opt.cpp | 14 +- ydb/tests/fq/yds/test_row_dispatcher.py | 20 +- 12 files changed, 745 insertions(+), 655 deletions(-) diff --git a/ydb/library/yql/providers/common/pushdown/collection.cpp b/ydb/library/yql/providers/common/pushdown/collection.cpp index c2dc834ee8c5..1c1f38de3b0e 100644 --- a/ydb/library/yql/providers/common/pushdown/collection.cpp +++ b/ydb/library/yql/providers/common/pushdown/collection.cpp @@ -11,697 +11,644 @@ using namespace NNodes; namespace { -bool ExprHasUtf8Type(const TExprBase& expr) { - auto typeAnn = expr.Ptr()->GetTypeAnn(); - auto itemType = GetSeqItemType(typeAnn); - if (!itemType) { - itemType = typeAnn; - } - if (itemType->GetKind() != ETypeAnnotationKind::Data) { - return false; +// Build and markup a predicate tree whose leaves are expressions: +// A OR B +// | | +// C AND D COALESCE(X, Y, Z) +// | | | | | +// Member Comparation ..... +// +// Each node has flag if it can be pushed entirely, +// Next some tree nodes will be split +class TPredicateMarkup { + using EFlag = TSettings::EFeatureFlag; + +public: + TPredicateMarkup(const TExprBase& lambdaArg, const TSettings& settings) + : LambdaArg(lambdaArg) + , Settings(settings) + {} + + void MarkupPredicates(const TExprBase& predicate, TPredicateNode& predicateTree) { + if (auto coalesce = predicate.Maybe()) { + if (Settings.IsEnabled(EFlag::JustPassthroughOperators)) { + CollectChildrenPredicates(predicate, predicateTree); + } else { + predicateTree.CanBePushed = CoalesceCanBePushed(coalesce.Cast()); + } + } else if (auto compare = predicate.Maybe()) { + predicateTree.CanBePushed = CompareCanBePushed(compare.Cast()); + } else if (auto exists = predicate.Maybe()) { + predicateTree.CanBePushed = ExistsCanBePushed(exists.Cast()); + } else if (auto notOp = predicate.Maybe()) { + const auto value = notOp.Cast().Value(); + TPredicateNode child(value); + MarkupPredicates(value, child); + predicateTree.Op = EBoolOp::Not; + predicateTree.CanBePushed = child.CanBePushed; + predicateTree.Children.emplace_back(child); + } else if (predicate.Maybe()) { + predicateTree.Op = EBoolOp::And; + CollectChildrenPredicates(predicate, predicateTree); + } else if (predicate.Maybe()) { + predicateTree.Op = EBoolOp::Or; + CollectChildrenPredicates(predicate, predicateTree); + } else if (Settings.IsEnabled(EFlag::LogicalXorOperator) && predicate.Maybe()) { + predicateTree.Op = EBoolOp::Xor; + CollectChildrenPredicates(predicate, predicateTree); + } else if (auto jsonExists = predicate.Maybe()) { + predicateTree.CanBePushed = JsonExistsCanBePushed(jsonExists.Cast()); + } else if (Settings.IsEnabled(EFlag::JustPassthroughOperators) && (predicate.Maybe() || predicate.Maybe())) { + CollectChildrenPredicates(predicate, predicateTree); + } else if (auto sqlIn = predicate.Maybe()) { + predicateTree.CanBePushed = SqlInCanBePushed(sqlIn.Cast()); + } else if (predicate.Ref().IsCallable({"IsNotDistinctFrom", "IsDistinctFrom"})) { + predicateTree.CanBePushed = IsDistinctCanBePushed(predicate); + } else if (auto apply = predicate.Maybe()) { + predicateTree.CanBePushed = ApplyCanBePushed(apply.Cast()); + } else if (Settings.IsEnabled(EFlag::ExpressionAsPredicate)) { + predicateTree.CanBePushed = CheckExpressionNodeForPushdown(predicate); + } else { + predicateTree.CanBePushed = false; + } } - auto dataTypeInfo = NUdf::GetDataTypeInfo(itemType->Cast()->GetSlot()); - return (std::string(dataTypeInfo.Name.data()) == "Utf8"); -} -bool IsLikeOperator(const TCoCompare& predicate) { - return predicate.Maybe() - || predicate.Maybe() - || predicate.Maybe(); -} + void CollectChildrenPredicates(const TExprBase& node, TPredicateNode& predicateTree) { + predicateTree.Children.reserve(node.Ref().ChildrenSize()); + predicateTree.CanBePushed = true; + for (const auto& childNodePtr: node.Ref().Children()) { + TPredicateNode child(childNodePtr); + MarkupPredicates(TExprBase(childNodePtr), child); -bool IsSupportedLikeForUtf8(const TExprBase& left, const TExprBase& right) { - if ((left.Maybe() && ExprHasUtf8Type(left)) - || (right.Maybe() && ExprHasUtf8Type(right))) - { - return true; + predicateTree.Children.emplace_back(child); + predicateTree.CanBePushed &= child.CanBePushed; + } } - return false; -} -bool IsSupportedPredicate(const TCoCompare& predicate, const TSettings& settings) { - if (predicate.Maybe()) { - return true; - } else if (predicate.Maybe()) { - return true; - } else if (predicate.Maybe()) { - return true; - } else if (predicate.Maybe()) { - return true; - } else if (predicate.Maybe()) { - return true; - } else if (predicate.Maybe()) { - return true; - } else if (predicate.Maybe()) { - return true; - } else if (settings.IsEnabled(TSettings::EFeatureFlag::LikeOperator) && IsLikeOperator(predicate)) { - return true; - } else if (predicate.Maybe()) { - return true; - } else if (predicate.Maybe()) { - return true; +private: + // Type helpers + static std::optional DataSlotFromDataType(const TTypeAnnotationNode* typeAnn) { + if (!typeAnn || typeAnn->GetKind() != ETypeAnnotationKind::Data) { + return std::nullopt; + } + return typeAnn->Cast()->GetSlot(); } - return false; -} - -bool IsSupportedDataType(const TCoDataCtor& node, const TSettings& settings) { - if (node.Maybe() || - node.Maybe() || - node.Maybe() || - node.Maybe() || - node.Maybe() || - node.Maybe() || - node.Maybe() || - node.Maybe() || - node.Maybe() || - node.Maybe() || - node.Maybe()) - { - return true; + static std::optional DataSlotFromOptionalDataType(const TTypeAnnotationNode* typeAnn) { + if (typeAnn->GetKind() == ETypeAnnotationKind::Optional) { + typeAnn = typeAnn->Cast()->GetItemType(); + } + return DataSlotFromDataType(typeAnn); } - if (settings.IsEnabled(TSettings::EFeatureFlag::TimestampCtor) && node.Maybe()) { - return true; + static const TTypeAnnotationNode* UnwrapExprType(const TTypeAnnotationNode* typeAnn) { + if (!typeAnn) { + return nullptr; + } + if (const auto typeExpr = typeAnn->Cast()) { + return typeExpr->GetType(); + } + return nullptr; } - if (settings.IsEnabled(TSettings::EFeatureFlag::StringTypes)) { - if (node.Maybe() || node.Maybe()) { - return true; - } + static bool IsStringType(std::optional dataSlot) { + return dataSlot && (IsDataTypeString(*dataSlot) || *dataSlot == NUdf::EDataSlot::JsonDocument); } - return false; -} + static bool IsUtf8Type(std::optional dataSlot) { + return dataSlot == NUdf::EDataSlot::Utf8; + } -bool IsSupportedCast(const TCoSafeCast& cast, const TSettings& settings) { - if (!settings.IsEnabled(TSettings::EFeatureFlag::CastExpression)) { - return false; + static bool IsDateTimeType(std::optional dataSlot) { + return dataSlot && IsDataTypeDateOrTzDateOrInterval(*dataSlot); } - auto maybeDataType = cast.Type().Maybe(); - if (!maybeDataType) { - if (const auto maybeOptionalType = cast.Type().Maybe()) { - maybeDataType = maybeOptionalType.Cast().ItemType().Maybe(); - } + static bool IsUuidType(std::optional dataSlot) { + return dataSlot == NUdf::EDataSlot::Uuid; } - YQL_ENSURE(maybeDataType.IsValid()); - const auto dataType = maybeDataType.Cast(); - if (dataType.Type().Value() == "Int32") { // TODO: Support any numeric casts. - return cast.Value().Maybe() || cast.Value().Maybe(); + static bool IsDecimalType(std::optional dataSlot) { + return dataSlot == NUdf::EDataSlot::Decimal; } - return false; -} -bool IsStringType(NYql::NUdf::TDataTypeId t) { - return t == NYql::NProto::String - || t == NYql::NProto::Utf8 - || t == NYql::NProto::Yson - || t == NYql::NProto::Json - || t == NYql::NProto::JsonDocument; -} + static bool IsDyNumberType(std::optional dataSlot) { + return dataSlot == NUdf::EDataSlot::DyNumber; + } -bool IsDateTimeType(NYql::NUdf::TDataTypeId t) { - return t == NYql::NProto::Date - || t == NYql::NProto::Datetime - || t == NYql::NProto::Timestamp - || t == NYql::NProto::Interval - || t == NYql::NProto::TzDate - || t == NYql::NProto::TzDatetime - || t == NYql::NProto::TzTimestamp - || t == NYql::NProto::Date32 - || t == NYql::NProto::Datetime64 - || t == NYql::NProto::Timestamp64 - || t == NYql::NProto::Interval64; -} + static bool IsNumericType(std::optional dataSlot) { + return dataSlot && IsDataTypeNumeric(*dataSlot); + } -bool IsUuidType(NYql::NUdf::TDataTypeId t) { - return t == NYql::NProto::Uuid; -} + static bool IsSignedIntegralType(std::optional dataSlot) { + return dataSlot && IsDataTypeSigned(*dataSlot); + } -bool IsDecimalType(NYql::NUdf::TDataTypeId t) { - return t == NYql::NProto::Decimal; -} + static bool IsUnsignedIntegralType(std::optional dataSlot) { + return dataSlot && IsDataTypeUnsigned(*dataSlot); + } -bool IsDyNumberType(NYql::NUdf::TDataTypeId t) { - return t == NYql::NProto::DyNumber; -} + static bool IsComparableTypes(const TTypeAnnotationNode* left, const TTypeAnnotationNode* right, bool equality) { + if (equality) { + return CanCompare(left, right) != ECompareOptions::Uncomparable; + } + return CanCompare(left, right) != ECompareOptions::Uncomparable; + } -bool IsComparableTypes(const TExprBase& leftNode, const TExprBase& rightNode, bool equality, - const TTypeAnnotationNode* inputType, const TSettings& settings) -{ - const TExprNode::TPtr leftPtr = leftNode.Ptr(); - const TExprNode::TPtr rightPtr = rightNode.Ptr(); + static bool IsStringExpr(const TExprBase& expr) { + return IsStringType(DataSlotFromOptionalDataType(expr.Ref().GetTypeAnn())); + } - auto getDataType = [inputType](const TExprNode::TPtr& node) { - auto type = node->GetTypeAnn(); + static bool IsUtf8Expr(const TExprBase& expr) { + return IsUtf8Type(DataSlotFromOptionalDataType(expr.Ref().GetTypeAnn())); + } - if (type->GetKind() == ETypeAnnotationKind::Unit) { - auto rowType = inputType->Cast(); - type = rowType->FindItemType(node->Content()); - } + // Callable helpers + static bool IsSimpleLikeOperator(const TCoCompare& predicate) { + // Only cases $A LIKE $B, where $B: + // "%str", "str%", "%str%" + return predicate.Maybe() + || predicate.Maybe() + || predicate.Maybe(); + } - if (type->GetKind() == ETypeAnnotationKind::Optional) { - type = type->Cast()->GetItemType(); + static std::vector GetComparisonNodes(const TExprBase& node) { + std::vector result; + if (const auto maybeList = node.Maybe()) { + const auto nodeList = maybeList.Cast(); + result.reserve(nodeList.Size()); + for (size_t i = 0; i < nodeList.Size(); ++i) { + result.emplace_back(nodeList.Item(i)); + } + } else { + result.emplace_back(node); + } + return result; + } + +private: + // Genric expression checking + bool IsSupportedDataType(const TCoDataCtor& node) const { + if (node.Maybe() || + node.Maybe() || + node.Maybe() || + node.Maybe() || + node.Maybe() || + node.Maybe() || + node.Maybe() || + node.Maybe() || + node.Maybe() || + node.Maybe() || + node.Maybe()) { + return true; } - - return type; - }; - - auto defaultCompare = [equality](const TTypeAnnotationNode* left, const TTypeAnnotationNode* right) { - if (equality) { - return CanCompare(left, right); + if (Settings.IsEnabled(EFlag::TimestampCtor) && node.Maybe()) { + return true; } - - return CanCompare(left, right); - }; - - auto canCompare = [&defaultCompare, &settings](const TTypeAnnotationNode* left, const TTypeAnnotationNode* right) { - if (left->GetKind() != ETypeAnnotationKind::Data || - right->GetKind() != ETypeAnnotationKind::Data) - { - return defaultCompare(left, right); + if (Settings.IsEnabled(EFlag::StringTypes) && (node.Maybe() || node.Maybe())) { + return true; } + return false; + } - auto leftTypeId = GetDataTypeInfo(left->Cast()->GetSlot()).TypeId; - auto rightTypeId = GetDataTypeInfo(right->Cast()->GetSlot()).TypeId; + bool IsMemberColumn(const TCoMember& member) const { + // We allow member acces only for top level predicate argument + return member.Struct().Raw() == LambdaArg.Raw(); + } - if (!settings.IsEnabled(TSettings::EFeatureFlag::StringTypes) && (IsStringType(leftTypeId) || IsStringType(rightTypeId))) { - return ECompareOptions::Uncomparable; + bool IsMemberColumn(const TExprBase& node) const { + if (const auto member = node.Maybe()) { + return IsMemberColumn(member.Cast()); } + return false; + } - if (!settings.IsEnabled(TSettings::EFeatureFlag::DateTimeTypes) && (IsDateTimeType(leftTypeId) || IsDateTimeType(rightTypeId))) { - return ECompareOptions::Uncomparable; + bool IsSupportedSafeCast(const TCoSafeCast& cast) { + if (!Settings.IsEnabled(EFlag::CastExpression)) { + return false; } - if (!settings.IsEnabled(TSettings::EFeatureFlag::UuidType) && (IsUuidType(leftTypeId) || IsUuidType(rightTypeId))) { - return ECompareOptions::Uncomparable; + const auto targetType = DataSlotFromOptionalDataType(UnwrapExprType(cast.Type().Ref().GetTypeAnn())); + if (targetType == EDataSlot::Bool || IsNumericType(targetType) || IsStringType(targetType) && Settings.IsEnabled(EFlag::StringTypes)) { + return CheckExpressionNodeForPushdown(cast.Value()); } + return false; + } - if (!settings.IsEnabled(TSettings::EFeatureFlag::DecimalType) && (IsDecimalType(leftTypeId) || IsDecimalType(rightTypeId))) { - return ECompareOptions::Uncomparable; + bool IsSupportedToBytes(const TExprBase& toBytes) { + if (!Settings.IsEnabled(EFlag::ToBytesFromStringExpressions)) { + return false; } - - if (!settings.IsEnabled(TSettings::EFeatureFlag::DyNumberType) && (IsDyNumberType(leftTypeId) || IsDyNumberType(rightTypeId))) { - return ECompareOptions::Uncomparable; + if (toBytes.Ref().ChildrenSize() != 1) { + return false; } - if (leftTypeId == rightTypeId) { - return ECompareOptions::Comparable; + auto toBytesExpr = TExprBase(toBytes.Ref().Child(0)); + if (!IsStringExpr(toBytesExpr)) { + return false; } + return CheckExpressionNodeForPushdown(toBytesExpr); + } - /* - * Check special case UInt32 <-> Datetime in case i can't put it inside switch without lot of copypaste - */ - if (leftTypeId == NYql::NProto::Uint32 && rightTypeId == NYql::NProto::Date) { - return ECompareOptions::Comparable; + bool IsSupportedLambda(const TCoLambda& lambda, ui64 numberArguments) { + const auto args = lambda.Args(); + if (args.Size() != numberArguments) { + return false; } - /* - * SSA program requires strict equality of some types, otherwise columnshard fails to execute comparison - */ - switch (leftTypeId) { - case NYql::NProto::Int8: - case NYql::NProto::Int16: - case NYql::NProto::Int32: - // SSA program cast those values to Int32 - if (rightTypeId == NYql::NProto::Int8 || - rightTypeId == NYql::NProto::Int16 || - rightTypeId == NYql::NProto::Int32 || - (settings.IsEnabled(TSettings::EFeatureFlag::ImplicitConversionToInt64) && rightTypeId == NYql::NProto::Int64)) - { - return ECompareOptions::Comparable; - } - break; - case NYql::NProto::Uint16: - if (rightTypeId == NYql::NProto::Date) { - return ECompareOptions::Comparable; - } - [[fallthrough]]; - case NYql::NProto::Uint8: - case NYql::NProto::Uint32: - // SSA program cast those values to Uint32 - if (rightTypeId == NYql::NProto::Uint8 || - rightTypeId == NYql::NProto::Uint16 || - rightTypeId == NYql::NProto::Uint32 || - (settings.IsEnabled(TSettings::EFeatureFlag::ImplicitConversionToInt64) && rightTypeId == NYql::NProto::Uint64)) - { - return ECompareOptions::Comparable; - } - break; - case NYql::NProto::Date: - // See arcadia/ydb/library/yql/dq/runtime/dq_arrow_helpers.cpp SwitchMiniKQLDataTypeToArrowType - if (rightTypeId == NYql::NProto::Uint16) { - return ECompareOptions::Comparable; - } - break; - case NYql::NProto::Datetime: - // See arcadia/ydb/library/yql/dq/runtime/dq_arrow_helpers.cpp SwitchMiniKQLDataTypeToArrowType - if (rightTypeId == NYql::NProto::Uint32) { - return ECompareOptions::Comparable; - } - break; - case NYql::NProto::Int64: - if (settings.IsEnabled(TSettings::EFeatureFlag::ImplicitConversionToInt64) && ( - rightTypeId == NYql::NProto::Int8 || - rightTypeId == NYql::NProto::Int16 || - rightTypeId == NYql::NProto::Int32)) - { - return ECompareOptions::Comparable; - } - break; - case NYql::NProto::Uint64: - if (settings.IsEnabled(TSettings::EFeatureFlag::ImplicitConversionToInt64) && ( - rightTypeId == NYql::NProto::Uint8 || - rightTypeId == NYql::NProto::Uint16 || - rightTypeId == NYql::NProto::Uint32)) - { - return ECompareOptions::Comparable; - } - break; - case NYql::NProto::Bool: - case NYql::NProto::Float: - case NYql::NProto::Double: - case NYql::NProto::Decimal: - case NYql::NProto::Timestamp: - case NYql::NProto::Interval: - // Obviosly here right node has not same type as left one - break; - default: - return defaultCompare(left, right); + // Add arguments into current context + for (const auto& argPtr : args.Ref().Children()) { + YQL_ENSURE(LambdaArguments.insert(argPtr.Get()).second, "Found duplicated lambda argument"); } - return ECompareOptions::Uncomparable; - }; + const bool result = CheckExpressionNodeForPushdown(lambda.Body()); - auto leftType = getDataType(leftPtr); - auto rightType = getDataType(rightPtr); - - if (canCompare(leftType, rightType) == ECompareOptions::Uncomparable) { - YQL_CVLOG(NLog::ELevel::DEBUG, settings.GetLogComponent()) << "Pushdown: " - << "Uncompatible types in compare of nodes: " - << leftPtr->Content() << " of type " << FormatType(leftType) - << " and " - << rightPtr->Content() << " of type " << FormatType(rightType); + // Remove arguments from current context + for (const auto& argPtr : args.Ref().Children()) { + LambdaArguments.erase(argPtr.Get()); + } - return false; + return result; } - return true; -} + bool IsSupportedFlatMap(const TCoFlatMap& flatMap) { + if (!Settings.IsEnabled(EFlag::FlatMapOverOptionals)) { + return false; + } -std::vector GetComparisonNodes(const TExprBase& node) { - std::vector res; - if (node.Maybe()) { - auto nodeList = node.Cast(); - res.reserve(nodeList.Size()); - for (size_t i = 0; i < nodeList.Size(); ++i) { - res.emplace_back(nodeList.Item(i)); + const auto input = flatMap.Input(); + if (!DataSlotFromOptionalDataType(input.Ref().GetTypeAnn())) { + // Supported only simple flat map over one optional + return false; + } + if (!CheckExpressionNodeForPushdown(input)) { + return false; } - } else { - res.emplace_back(node); - } - return res; -} -bool IsMemberColumn(const TCoMember& member, const TExprBase& lambdaArg) { - return member.Struct().Raw() == lambdaArg.Raw(); -} + // Expected exactly one argument for flat map lambda + return IsSupportedLambda(flatMap.Lambda(), 1); + } -bool IsMemberColumn(const TExprBase& node, const TExprBase& lambdaArg) { - if (auto member = node.Maybe()) { - return IsMemberColumn(member.Cast(), lambdaArg); + bool IsLambdaArgument(const TExprBase& expr) const { + return LambdaArguments.contains(expr.Raw()); } - return false; -} -bool CheckExpressionNodeForPushdown(const TExprBase& node, const TExprBase& lambdaArg, const TExprBase& lambdaBody, const TSettings& settings) { - if (auto maybeSafeCast = node.Maybe()) { - return IsSupportedCast(maybeSafeCast.Cast(), settings); - } else if (auto maybeData = node.Maybe()) { - return IsSupportedDataType(maybeData.Cast(), settings); - } else if (auto maybeMember = node.Maybe()) { - return IsMemberColumn(maybeMember.Cast(), lambdaArg); - } else if (settings.IsEnabled(TSettings::EFeatureFlag::JsonQueryOperators) && node.Maybe()) { - if (!node.Maybe()) { - return false; + bool CheckExpressionNodeForPushdown(const TExprBase& node) { + if (auto maybeSafeCast = node.Maybe()) { + return IsSupportedSafeCast(maybeSafeCast.Cast()); + } + if (node.Ref().IsCallable({"ToBytes"})) { + return IsSupportedToBytes(node); + } + if (auto maybeData = node.Maybe()) { + return IsSupportedDataType(maybeData.Cast()); } - auto jsonOp = node.Cast(); - if (!jsonOp.Json().Maybe() || !jsonOp.JsonPath().Maybe()) { + if (auto maybeMember = node.Maybe()) { + return IsMemberColumn(maybeMember.Cast()); + } + if (Settings.IsEnabled(EFlag::JsonQueryOperators) && node.Maybe()) { + if (!node.Maybe()) { + return false; + } + // Currently we support only simple columns in pushdown - return false; + const auto jsonOp = node.Cast(); + return jsonOp.Json().Maybe() && jsonOp.JsonPath().Maybe(); } - return true; - } else if (node.Maybe()) { - return true; - } else if (settings.IsEnabled(TSettings::EFeatureFlag::ParameterExpression) && node.Maybe()) { - return true; - } else if (const auto op = node.Maybe(); op && settings.IsEnabled(TSettings::EFeatureFlag::UnaryOperators)) { - return CheckExpressionNodeForPushdown(op.Cast().Arg(), lambdaArg, lambdaBody, settings); - } else if (const auto op = node.Maybe(); op && settings.IsEnabled(TSettings::EFeatureFlag::ArithmeticalExpressions)) { - if (!settings.IsEnabled(TSettings::EFeatureFlag::DivisionExpressions) && (op.Maybe() || op.Maybe())) { - return false; + if (node.Maybe()) { + return true; } - return CheckExpressionNodeForPushdown(op.Cast().Left(), lambdaArg, lambdaBody, settings) && CheckExpressionNodeForPushdown(op.Cast().Right(), lambdaArg, lambdaBody, settings); - } else if (settings.IsEnabled(TSettings::EFeatureFlag::JustPassthroughOperators) && (node.Maybe() || node.Maybe())) { - for (const auto& childNodePtr : node.Ref().Children()) { - if (!CheckExpressionNodeForPushdown(TExprBase(childNodePtr), lambdaArg, lambdaBody, settings)) { + if (Settings.IsEnabled(EFlag::ParameterExpression) && node.Maybe()) { + return true; + } + if (const auto op = node.Maybe(); op && Settings.IsEnabled(EFlag::UnaryOperators)) { + return CheckExpressionNodeForPushdown(op.Cast().Arg()); + } + if (const auto op = node.Maybe(); op && Settings.IsEnabled(EFlag::ArithmeticalExpressions)) { + if (!Settings.IsEnabled(EFlag::DivisionExpressions) && (op.Maybe() || op.Maybe())) { return false; } + return CheckExpressionNodeForPushdown(op.Cast().Left()) && CheckExpressionNodeForPushdown(op.Cast().Right()); } - return true; - } else if (auto maybeIf = node.Maybe()) { - if (!settings.IsEnabled(TSettings::EFeatureFlag::JustPassthroughOperators)) { - return false; + if (Settings.IsEnabled(EFlag::JustPassthroughOperators) && (node.Maybe() || node.Maybe())) { + for (const auto& childNodePtr : node.Ref().Children()) { + if (!CheckExpressionNodeForPushdown(TExprBase(childNodePtr))) { + return false; + } + } + return true; } + if (const auto maybeIf = node.Maybe(); maybeIf && Settings.IsEnabled(EFlag::JustPassthroughOperators)) { + const auto& sqlIf = maybeIf.Cast(); + const auto& predicate = sqlIf.Predicate(); - const auto& sqlIf = maybeIf.Cast(); - const auto& predicate = sqlIf.Predicate(); - - // Check if predicate pushdown - TPredicateNode ifPredicate(predicate); - CollectPredicates(TExprBase(predicate), ifPredicate, lambdaArg, lambdaBody, settings); + // Check if predicate pushdown + TPredicateNode ifPredicate(predicate); + MarkupPredicates(predicate, ifPredicate); - // Check if expressions pushdown - return ifPredicate.CanBePushed - && CheckExpressionNodeForPushdown(sqlIf.ThenValue(), lambdaArg, lambdaBody, settings) - && CheckExpressionNodeForPushdown(sqlIf.ElseValue(), lambdaArg, lambdaBody, settings); + // Check if expressions pushdown + return ifPredicate.CanBePushed + && CheckExpressionNodeForPushdown(sqlIf.ThenValue()) + && CheckExpressionNodeForPushdown(sqlIf.ElseValue()); + } + if (auto flatMap = node.Maybe()) { + return IsSupportedFlatMap(flatMap.Cast()); + } + return IsLambdaArgument(node); } - return false; -} -bool CheckComparisonParametersForPushdown(const TCoCompare& compare, const TExprBase& lambdaArg, const TExprBase& input, const TSettings& settings) { - const TTypeAnnotationNode* inputType = input.Ptr()->GetTypeAnn(); - switch (inputType->GetKind()) { - case ETypeAnnotationKind::Flow: - inputType = inputType->Cast()->GetItemType(); - break; - case ETypeAnnotationKind::Stream: - inputType = inputType->Cast()->GetItemType(); - break; - case ETypeAnnotationKind::Struct: - break; - default: - YQL_ENSURE(false, "Unsupported type of incoming data: " << (ui32)inputType->GetKind()); - // We do not know how process input that is not a sequence of elements +private: + // Comprasion checking + bool IsSupportedLikeOperator(const TCoCompare& compare) const { + if (!IsSimpleLikeOperator(compare)) { return false; + } + if (Settings.IsEnabled(EFlag::LikeOperator)) { + return true; + } + if (Settings.IsEnabled(EFlag::LikeOperatorOnlyForUtf8) && IsUtf8Expr(compare.Left()) && IsUtf8Expr(compare.Right())) { + return true; + } + return false; } - YQL_ENSURE(inputType->GetKind() == ETypeAnnotationKind::Struct); - if (inputType->GetKind() != ETypeAnnotationKind::Struct) { - // We do not know how process input that is not a sequence of elements + bool IsSupportedCompareOperator(const TCoCompare& compare) const { + if (compare.Maybe() || + compare.Maybe() || + compare.Maybe() || + compare.Maybe() || + compare.Maybe() || + compare.Maybe() || + compare.Maybe() || + compare.Maybe()) { + return true; + } + if (IsSupportedLikeOperator(compare)) { + return true; + } return false; } - const auto leftList = GetComparisonNodes(compare.Left()); - const auto rightList = GetComparisonNodes(compare.Right()); - YQL_ENSURE(leftList.size() == rightList.size(), "Different sizes of lists in comparison!"); + bool IsComparableArguments(const TExprBase& left, const TExprBase& right, bool equality) const { + if (Settings.IsEnabled(EFlag::DoNotCheckCompareArgumentsTypes)) { + return true; + } - for (size_t i = 0; i < leftList.size(); ++i) { - if (!CheckExpressionNodeForPushdown(leftList[i], lambdaArg, input, settings) || !CheckExpressionNodeForPushdown(rightList[i], lambdaArg, input, settings)) { + const auto leftType = DataSlotFromOptionalDataType(left.Ref().GetTypeAnn()); + const auto rightType = DataSlotFromOptionalDataType(right.Ref().GetTypeAnn()); + if (!leftType || !rightType) { + return IsComparableTypes(left.Ref().GetTypeAnn(), right.Ref().GetTypeAnn(), equality); + } + if (!Settings.IsEnabled(EFlag::StringTypes) && (IsStringType(leftType) || IsStringType(rightType))) { return false; } - - if (!settings.IsEnabled(TSettings::EFeatureFlag::DoNotCheckCompareArgumentsTypes)) { - if (!IsComparableTypes(leftList[i], rightList[i], compare.Maybe() || compare.Maybe(), inputType, settings)) { - return false; - } + if (!Settings.IsEnabled(EFlag::DateTimeTypes) && (IsDateTimeType(leftType) || IsDateTimeType(rightType))) { + return false; } - - if (IsLikeOperator(compare) && settings.IsEnabled(TSettings::EFeatureFlag::LikeOperatorOnlyForUtf8) && !IsSupportedLikeForUtf8(leftList[i], rightList[i])) { - // (KQP OLAP) If SSA_RUNTIME_VERSION == 2 Column Shard doesn't have LIKE kernel for binary strings + if (!Settings.IsEnabled(EFlag::UuidType) && (IsUuidType(leftType) || IsUuidType(rightType))) { return false; } - } - - return true; -} - -bool CompareCanBePushed(const TCoCompare& compare, const TExprBase& lambdaArg, const TExprBase& lambdaBody, const TSettings& settings) { - if (!IsSupportedPredicate(compare, settings)) { - return false; - } - - if (!CheckComparisonParametersForPushdown(compare, lambdaArg, lambdaBody, settings)) { - return false; - } - - return true; -} - -bool SqlInCanBePushed(const TCoSqlIn& sqlIn, const TExprBase& lambdaArg, const TExprBase& lambdaBody, const TSettings& settings) { - const TExprBase& expr = sqlIn.Collection(); - const TExprBase& lookup = sqlIn.Lookup(); + if (!Settings.IsEnabled(EFlag::DecimalType) && (IsDecimalType(leftType) || IsDecimalType(rightType))) { + return false; + } + if (!Settings.IsEnabled(EFlag::DyNumberType) && (IsDyNumberType(leftType) || IsDyNumberType(rightType))) { + return false; + } + if (leftType == rightType) { + return true; + } - if (!CheckExpressionNodeForPushdown(lookup, lambdaArg, lambdaBody, settings)) { - return false; - } + // We check: + // - signed / unsigned quality by each side + // - sizes of data types like data / interval + switch (*leftType) { + case NUdf::EDataSlot::Int8: + case NUdf::EDataSlot::Int16: + case NUdf::EDataSlot::Int32: + case NUdf::EDataSlot::Int64: + return Settings.IsEnabled(EFlag::ImplicitConversionToInt64) && IsSignedIntegralType(rightType); + + case NUdf::EDataSlot::Uint8: + case NUdf::EDataSlot::Uint16: + if (rightType == NUdf::EDataSlot::Date) { + return true; + } + [[fallthrough]]; + case NUdf::EDataSlot::Uint32: + if (rightType == NUdf::EDataSlot::Datetime) { + return true; + } + [[fallthrough]]; + case NUdf::EDataSlot::Uint64: + if (rightType == NUdf::EDataSlot::Timestamp || rightType == NUdf::EDataSlot::Interval) { + return true; + } + return Settings.IsEnabled(EFlag::ImplicitConversionToInt64) && IsUnsignedIntegralType(rightType); + + case NUdf::EDataSlot::Date: + return rightType == NUdf::EDataSlot::Uint16; + case NUdf::EDataSlot::Datetime: + return rightType == NUdf::EDataSlot::Uint32; + case NUdf::EDataSlot::Timestamp: + case NUdf::EDataSlot::Interval: + return rightType == NUdf::EDataSlot::Uint64; + + case NUdf::EDataSlot::Bool: + case NUdf::EDataSlot::Float: + case NUdf::EDataSlot::Double: + case NUdf::EDataSlot::Decimal: + return false; - TExprNode::TPtr collection; - if (expr.Ref().IsList()) { - collection = expr.Ptr(); - } else if (auto maybeAsList = expr.Maybe()) { - collection = maybeAsList.Cast().Ptr(); - } else { - return false; + default: + return IsComparableTypes(left.Ref().GetTypeAnn(), right.Ref().GetTypeAnn(), equality); + } } - const TTypeAnnotationNode* inputType = lambdaBody.Ptr()->GetTypeAnn(); - for (auto& child : collection->Children()) { - if (!CheckExpressionNodeForPushdown(TExprBase(child), lambdaArg, lambdaBody, settings)) { - return false; + bool IsSupportedComparisonParameters(const TCoCompare& compare) { + const TTypeAnnotationNode* inputType = LambdaArg.Ptr()->GetTypeAnn(); + switch (inputType->GetKind()) { + case ETypeAnnotationKind::Flow: + inputType = inputType->Cast()->GetItemType(); + break; + case ETypeAnnotationKind::Stream: + inputType = inputType->Cast()->GetItemType(); + break; + case ETypeAnnotationKind::Struct: + break; + default: + // We do not know how process input that is not a sequence of elements + return false; } + YQL_ENSURE(inputType->GetKind() == ETypeAnnotationKind::Struct, "Unexpected predicate input type " << ui64(inputType->GetKind())); - if (!settings.IsEnabled(TSettings::EFeatureFlag::DoNotCheckCompareArgumentsTypes)) { - if (!IsComparableTypes(lookup, TExprBase(child), false, inputType, settings)) { + const auto leftList = GetComparisonNodes(compare.Left()); + const auto rightList = GetComparisonNodes(compare.Right()); + YQL_ENSURE(leftList.size() == rightList.size(), "Compression parameters should have same size but got " << leftList.size() << " vs " << rightList.size()); + + for (size_t i = 0; i < leftList.size(); ++i) { + if (!CheckExpressionNodeForPushdown(leftList[i]) || !CheckExpressionNodeForPushdown(rightList[i])) { + return false; + } + if (!IsComparableArguments(leftList[i], rightList[i], compare.Maybe() || compare.Maybe())) { return false; } } + return true; } - return true; -} -bool IsDistinctCanBePushed(const TExprBase& predicate, const TExprBase& lambdaArg, const TExprBase& lambdaBody, const TSettings& settings) { - if (predicate.Ref().ChildrenSize() != 2 ) { - return false; - } - auto expr1 = TExprBase(predicate.Ref().Child(0)); - auto expr2 = TExprBase(predicate.Ref().Child(1)); - if (!CheckExpressionNodeForPushdown(expr1, lambdaArg, lambdaBody, settings) - || !CheckExpressionNodeForPushdown(expr2, lambdaArg, lambdaBody, settings)) { - return false; - } - if (!settings.IsEnabled(TSettings::EFeatureFlag::DoNotCheckCompareArgumentsTypes) - && !IsComparableTypes(expr1, expr2, false, lambdaBody.Ptr()->GetTypeAnn(), settings)) { - return false; + bool CompareCanBePushed(const TCoCompare& compare) { + if (!IsSupportedCompareOperator(compare)) { + return false; + } + if (!IsSupportedComparisonParameters(compare)) { + return false; + } + return true; } - return true; -} -bool SafeCastCanBePushed(const TCoFlatMap& flatmap, const TExprBase& lambdaArg, const TExprBase& lambdaBody, const TSettings& settings) { - /* - * There are three ways of comparison in following format: - * - * FlatMap (LeftArgument, FlatMap(RightArgument(), Just(Predicate)) - * - * Examples: - * FlatMap (SafeCast(), FlatMap(Member(), Just(Comparison)) - * FlatMap (Member(), FlatMap(SafeCast(), Just(Comparison)) - * FlatMap (SafeCast(), FlatMap(SafeCast(), Just(Comparison)) - */ - auto maybeFlatmap = flatmap.Lambda().Body().Maybe(); - if (!maybeFlatmap.IsValid()) { - return false; - } +private: + // Boolean predicates checking + bool SqlInCanBePushed(const TCoSqlIn& sqlIn) { + if (!Settings.IsEnabled(EFlag::InOperator)) { + return false; + } - auto leftList = GetComparisonNodes(flatmap.Input()); - auto rightList = GetComparisonNodes(maybeFlatmap.Cast().Input()); - YQL_ENSURE(leftList.size() == rightList.size(), "Different sizes of lists in comparison!"); + const TExprBase& expr = sqlIn.Collection(); + const TExprBase& lookup = sqlIn.Lookup(); + if (!CheckExpressionNodeForPushdown(lookup)) { + return false; + } - for (size_t i = 0; i < leftList.size(); ++i) { - if (!CheckExpressionNodeForPushdown(leftList[i], lambdaArg, lambdaBody, settings) || !CheckExpressionNodeForPushdown(rightList[i], lambdaArg, lambdaBody, settings)) { + TExprNode::TPtr collection; + if (expr.Ref().IsList()) { + collection = expr.Ptr(); + } else if (auto maybeAsList = expr.Maybe()) { + collection = maybeAsList.Cast().Ptr(); + } else { return false; } - } - auto maybeJust = maybeFlatmap.Cast().Lambda().Body().Maybe(); - if (!maybeJust.IsValid()) { - return false; + for (const auto& childNodePtr : collection->Children()) { + if (!CheckExpressionNodeForPushdown(TExprBase(childNodePtr))) { + return false; + } + if (!IsComparableArguments(lookup, TExprBase(childNodePtr), true)) { + return false; + } + } + return true; } - auto maybePredicate = maybeJust.Cast().Input().Maybe(); - if (!maybePredicate.IsValid()) { - return false; - } + bool IsDistinctCanBePushed(const TExprBase& predicate) { + if (!Settings.IsEnabled(EFlag::IsDistinctOperator)) { + return false; + } + if (predicate.Ref().ChildrenSize() != 2) { + return false; + } - auto predicate = maybePredicate.Cast(); - if (!IsSupportedPredicate(predicate, settings)) { - return false; + const auto left = TExprBase(predicate.Ref().Child(0)); + const auto right = TExprBase(predicate.Ref().Child(1)); + if (!CheckExpressionNodeForPushdown(left) || !CheckExpressionNodeForPushdown(right)) { + return false; + } + return IsComparableArguments(left, right, true); } - return true; -} + bool JsonExistsCanBePushed(const TCoJsonExists& jsonExists) const { + if (!Settings.IsEnabled(EFlag::JsonExistsOperator)) { + return false; + } -bool JsonExistsCanBePushed(const TCoJsonExists& jsonExists, const TExprBase& lambdaArg) { - auto maybeMember = jsonExists.Json().Maybe(); - if (!maybeMember || !jsonExists.JsonPath().Maybe()) { - // Currently we support only simple columns in pushdown - return false; - } - if (!IsMemberColumn(maybeMember.Cast(), lambdaArg)) { - return false; + const auto maybeMember = jsonExists.Json().Maybe(); + if (!maybeMember || !jsonExists.JsonPath().Maybe()) { + // Currently we support only simple columns in pushdown + return false; + } + return IsMemberColumn(maybeMember.Cast()); } - return true; -} -bool CoalesceCanBePushed(const TCoCoalesce& coalesce, const TExprBase& lambdaArg, const TExprBase& lambdaBody, const TSettings& settings) { - if (!coalesce.Value().Maybe()) { - return false; - } - auto predicate = coalesce.Predicate(); + bool CoalesceCanBePushed(const TCoCoalesce& coalesce) { + if (!coalesce.Value().Maybe()) { + return false; + } - if (auto maybeCompare = predicate.Maybe()) { - return CompareCanBePushed(maybeCompare.Cast(), lambdaArg, lambdaBody, settings); - } else if (auto maybeFlatmap = predicate.Maybe()) { - return SafeCastCanBePushed(maybeFlatmap.Cast(), lambdaArg, lambdaBody, settings); - } else if (settings.IsEnabled(TSettings::EFeatureFlag::JsonExistsOperator) && predicate.Maybe()) { - auto jsonExists = predicate.Cast(); - return JsonExistsCanBePushed(jsonExists, lambdaArg); + TPredicateNode predicateTree(coalesce.Predicate()); + MarkupPredicates(coalesce.Predicate(), predicateTree); + return predicateTree.CanBePushed; } - return false; -} - -bool ExistsCanBePushed(const TCoExists& exists, const TExprBase& lambdaArg) { - return IsMemberColumn(exists.Optional(), lambdaArg); -} - -bool UdfCanBePushed(const TCoUdf& udf, const TExprNode::TListType& children, const TExprBase& lambdaArg, const TExprBase& lambdaBody, const TSettings& settings) { - const TString functionName(udf.MethodName()); - if (!settings.IsEnabledFunction(functionName)) { - return false; + bool ExistsCanBePushed(const TCoExists& exists) const { + return IsMemberColumn(exists.Optional()); } - if (functionName == "Re2.Grep") { - if (children.size() != 2) { - // Expected exactly one argument (first child of apply is callable) + bool UdfCanBePushed(const TCoUdf& udf, const TExprNode::TListType& children) { + const TString functionName(udf.MethodName()); + if (!Settings.IsEnabledFunction(functionName)) { return false; } - const auto& udfSettings = udf.Settings(); - if (udfSettings && !udfSettings.Cast().Empty()) { - // Expected empty udf settings - return false; - } + if (functionName == "Re2.Grep") { + if (children.size() != 2) { + // Expected exactly one argument (first child of apply is callable) + return false; + } - const auto& maybeRunConfig = udf.RunConfigValue(); - if (!maybeRunConfig) { - // Expected non empty run config - return false; - } - const auto& runConfig = maybeRunConfig.Cast().Ref(); + const auto& udfSettings = udf.Settings(); + if (udfSettings && !udfSettings.Cast().Empty()) { + // Expected empty udf settings + return false; + } - if (runConfig.ChildrenSize() != 2) { - // Expected exactly two run config settings - return false; - } - if (!TExprBase(runConfig.Child(1)).Maybe()) { - // Expected empty regexp settings - return false; - } + const auto& maybeRunConfig = udf.RunConfigValue(); + if (!maybeRunConfig) { + // Expected non empty run config + return false; + } + const auto& runConfig = maybeRunConfig.Cast().Ref(); - return CheckExpressionNodeForPushdown(TExprBase(runConfig.Child(0)), lambdaArg, lambdaBody, settings); - } - return false; -} + if (runConfig.ChildrenSize() != 2) { + // Expected exactly two run config settings + return false; + } + if (!TExprBase(runConfig.Child(1)).Maybe()) { + // Expected empty regexp settings + return false; + } -bool ApplyCanBePushed(const TCoApply& apply, const TExprBase& lambdaArg, const TExprBase& lambdaBody, const TSettings& settings) { - // Check callable - if (auto udf = apply.Callable().Maybe()) { - if (!UdfCanBePushed(udf.Cast(), apply.Ref().ChildrenList(), lambdaArg, lambdaBody, settings)) { - return false; + return CheckExpressionNodeForPushdown(TExprBase(runConfig.Child(0))); } + return false; } - // Check arguments - for (size_t i = 1; i < apply.Ref().ChildrenSize(); ++i) { - if (!CheckExpressionNodeForPushdown(TExprBase(apply.Ref().Child(i)), lambdaArg, lambdaBody, settings)) { - return false; + bool ApplyCanBePushed(const TCoApply& apply) { + // Check callable + if (auto udf = apply.Callable().Maybe()) { + if (!UdfCanBePushed(udf.Cast(), apply.Ref().ChildrenList())) { + return false; + } } - } - return true; -} -void CollectChildrenPredicates(const TExprNode& opNode, TPredicateNode& predicateTree, const TExprBase& lambdaArg, const TExprBase& lambdaBody, const TSettings& settings) { - predicateTree.Children.reserve(opNode.ChildrenSize()); - predicateTree.CanBePushed = true; - for (const auto& childNodePtr: opNode.Children()) { - TPredicateNode child(childNodePtr); - const TExprBase base(childNodePtr); - if (const auto maybeCtor = base.Maybe()) - child.CanBePushed = IsSupportedDataType(maybeCtor.Cast(), settings); - else - CollectPredicates(base, child, lambdaArg, lambdaBody, settings); - predicateTree.Children.emplace_back(child); - predicateTree.CanBePushed &= child.CanBePushed; + // Check arguments + for (size_t i = 1; i < apply.Ref().ChildrenSize(); ++i) { + if (!CheckExpressionNodeForPushdown(TExprBase(apply.Ref().Child(i)))) { + return false; + } + } + return true; } -} -void CollectExpressionPredicate(TPredicateNode& predicateTree, const TCoMember& member, const TExprBase& lambdaArg) { - predicateTree.CanBePushed = IsMemberColumn(member, lambdaArg); -} +private: + const TExprBase& LambdaArg; // Predicate input item, has struct type + const TSettings& Settings; + + std::unordered_set LambdaArguments; +}; } // anonymous namespace end void CollectPredicates(const TExprBase& predicate, TPredicateNode& predicateTree, const TExprBase& lambdaArg, const TExprBase& lambdaBody, const TSettings& settings) { - if (predicate.Maybe()) { - if (settings.IsEnabled(TSettings::EFeatureFlag::JustPassthroughOperators)) - CollectChildrenPredicates(predicate.Ref(), predicateTree, lambdaArg, lambdaBody, settings); - else { - auto coalesce = predicate.Cast(); - predicateTree.CanBePushed = CoalesceCanBePushed(coalesce, lambdaArg, lambdaBody, settings); - } - } else if (predicate.Maybe()) { - auto compare = predicate.Cast(); - predicateTree.CanBePushed = CompareCanBePushed(compare, lambdaArg, lambdaBody, settings); - } else if (predicate.Maybe()) { - auto exists = predicate.Cast(); - predicateTree.CanBePushed = ExistsCanBePushed(exists, lambdaArg); - } else if (predicate.Maybe()) { - predicateTree.Op = EBoolOp::Not; - auto notOp = predicate.Cast(); - TPredicateNode child(notOp.Value()); - CollectPredicates(notOp.Value(), child, lambdaArg, lambdaBody, settings); - predicateTree.CanBePushed = child.CanBePushed; - predicateTree.Children.emplace_back(child); - } else if (predicate.Maybe()) { - predicateTree.Op = EBoolOp::And; - CollectChildrenPredicates(predicate.Ref(), predicateTree, lambdaArg, lambdaBody, settings); - } else if (predicate.Maybe()) { - predicateTree.Op = EBoolOp::Or; - CollectChildrenPredicates(predicate.Ref(), predicateTree, lambdaArg, lambdaBody, settings); - } else if (settings.IsEnabled(TSettings::EFeatureFlag::LogicalXorOperator) && predicate.Maybe()) { - predicateTree.Op = EBoolOp::Xor; - CollectChildrenPredicates(predicate.Ref(), predicateTree, lambdaArg, lambdaBody, settings); - } else if (settings.IsEnabled(TSettings::EFeatureFlag::JsonExistsOperator) && predicate.Maybe()) { - auto jsonExists = predicate.Cast(); - predicateTree.CanBePushed = JsonExistsCanBePushed(jsonExists, lambdaArg); - } else if (settings.IsEnabled(TSettings::EFeatureFlag::ExpressionAsPredicate) && predicate.Maybe()) { - CollectExpressionPredicate(predicateTree, predicate.Cast(), lambdaArg); - } else if (settings.IsEnabled(TSettings::EFeatureFlag::JustPassthroughOperators) && (predicate.Maybe() || predicate.Maybe())) { - CollectChildrenPredicates(predicate.Ref(), predicateTree, lambdaArg, lambdaBody, settings); - } else if (settings.IsEnabled(TSettings::EFeatureFlag::InOperator) && predicate.Maybe()) { - auto sqlIn = predicate.Cast(); - predicateTree.CanBePushed = SqlInCanBePushed(sqlIn, lambdaArg, lambdaBody, settings); - } else if (settings.IsEnabled(TSettings::EFeatureFlag::IsDistinctOperator) && - (predicate.Ref().IsCallable({"IsNotDistinctFrom", "IsDistinctFrom"}))) { - predicateTree.CanBePushed = IsDistinctCanBePushed(predicate, lambdaArg, lambdaBody, settings); - } else if (auto maybeApply = predicate.Maybe()) { - predicateTree.CanBePushed = ApplyCanBePushed(maybeApply.Cast(), lambdaArg, lambdaBody, settings); - } else { - predicateTree.CanBePushed = false; - } + TPredicateMarkup markup(lambdaArg, settings); + markup.MarkupPredicates(predicate, predicateTree); } } // namespace NYql::NPushdown diff --git a/ydb/library/yql/providers/common/pushdown/physical_opt.cpp b/ydb/library/yql/providers/common/pushdown/physical_opt.cpp index ed2b401de8d5..7f31c520d212 100644 --- a/ydb/library/yql/providers/common/pushdown/physical_opt.cpp +++ b/ydb/library/yql/providers/common/pushdown/physical_opt.cpp @@ -39,7 +39,7 @@ TPredicateNode SplitForPartialPushdown(const NPushdown::TPredicateNode& predicat } -TMaybeNode MakePushdownPredicate(const TCoLambda& lambda, TExprContext& ctx, const TPositionHandle& pos, const TSettings& settings) { +NPushdown::TPredicateNode MakePushdownNode(const NNodes::TCoLambda& lambda, TExprContext& ctx, const TPositionHandle& pos, const TSettings& settings) { auto lambdaArg = lambda.Args().Arg(0).Ptr(); YQL_LOG(TRACE) << "Push filter. Initial filter lambda: " << NCommon::ExprToPrettyString(ctx, lambda.Ref()); @@ -54,7 +54,11 @@ TMaybeNode MakePushdownPredicate(const TCoLambda& lambda, TExprContex NPushdown::CollectPredicates(optionalIf.Predicate(), predicateTree, TExprBase(lambdaArg), TExprBase(lambdaArg), settings); YQL_ENSURE(predicateTree.IsValid(), "Collected filter predicates are invalid"); - NPushdown::TPredicateNode predicateToPush = SplitForPartialPushdown(predicateTree, ctx, pos, settings); + return SplitForPartialPushdown(predicateTree, ctx, pos, settings); +} + +TMaybeNode MakePushdownPredicate(const TCoLambda& lambda, TExprContext& ctx, const TPositionHandle& pos, const TSettings& settings) { + NPushdown::TPredicateNode predicateToPush = MakePushdownNode(lambda, ctx, pos, settings); if (!predicateToPush.IsValid()) { return {}; } @@ -64,7 +68,7 @@ TMaybeNode MakePushdownPredicate(const TCoLambda& lambda, TExprContex .Args({"filter_row"}) .Body() .Apply(predicateToPush.ExprNode.Cast()) - .With(TExprBase(lambdaArg), "filter_row") + .With(lambda.Args().Arg(0), "filter_row") .Build() .Done(); // clang-format on diff --git a/ydb/library/yql/providers/common/pushdown/physical_opt.h b/ydb/library/yql/providers/common/pushdown/physical_opt.h index eb3ff99d1ec9..d185abe832e9 100644 --- a/ydb/library/yql/providers/common/pushdown/physical_opt.h +++ b/ydb/library/yql/providers/common/pushdown/physical_opt.h @@ -1,11 +1,15 @@ #pragma once +#include "predicate_node.h" + #include #include +#include #include namespace NYql::NPushdown { +NPushdown::TPredicateNode MakePushdownNode(const NNodes::TCoLambda& lambda, TExprContext& ctx, const TPositionHandle& pos, const TSettings& settings); NNodes::TMaybeNode MakePushdownPredicate(const NNodes::TCoLambda& lambda, TExprContext& ctx, const TPositionHandle& pos, const TSettings& settings); } // namespace NYql::NPushdown diff --git a/ydb/library/yql/providers/common/pushdown/predicate_node.cpp b/ydb/library/yql/providers/common/pushdown/predicate_node.cpp index 14c3cb77b9d9..6fed8221e3e5 100644 --- a/ydb/library/yql/providers/common/pushdown/predicate_node.cpp +++ b/ydb/library/yql/providers/common/pushdown/predicate_node.cpp @@ -30,6 +30,16 @@ bool TPredicateNode::IsValid() const { return res && ExprNode.IsValid(); } +bool TPredicateNode::IsEmpty() const { + if (!ExprNode || !IsValid()) { + return true; + } + if (const auto maybeBool = ExprNode.Maybe()) { + return TStringBuf(maybeBool.Cast().Literal()) == "true"sv; + } + return false; +} + void TPredicateNode::SetPredicates(const std::vector& predicates, TExprContext& ctx, TPositionHandle pos, EBoolOp op) { auto predicatesSize = predicates.size(); if (predicatesSize == 0) { diff --git a/ydb/library/yql/providers/common/pushdown/predicate_node.h b/ydb/library/yql/providers/common/pushdown/predicate_node.h index 63697848aa3a..60eba545012e 100644 --- a/ydb/library/yql/providers/common/pushdown/predicate_node.h +++ b/ydb/library/yql/providers/common/pushdown/predicate_node.h @@ -23,6 +23,7 @@ struct TPredicateNode { ~TPredicateNode(); bool IsValid() const; + bool IsEmpty() const; void SetPredicates(const std::vector& predicates, TExprContext& ctx, TPositionHandle pos, EBoolOp op); NNodes::TMaybeNode ExprNode; diff --git a/ydb/library/yql/providers/common/pushdown/settings.h b/ydb/library/yql/providers/common/pushdown/settings.h index 5bf83bc2ab13..518334b810fa 100644 --- a/ydb/library/yql/providers/common/pushdown/settings.h +++ b/ydb/library/yql/providers/common/pushdown/settings.h @@ -38,7 +38,9 @@ struct TSettings { // May be partially pushdowned as: // $A OR $C // In case of unsupported / complicated expressions $B and $D - SplitOrOperator = 1 << 22 + SplitOrOperator = 1 << 22, + ToBytesFromStringExpressions = 1 << 23, // ToBytes(string like) + FlatMapOverOptionals = 1 << 24 // FlatMap(Optional, Lmabda (T) -> Optional) }; explicit TSettings(NLog::EComponent logComponent) diff --git a/ydb/library/yql/providers/generic/connector/api/service/protos/connector.proto b/ydb/library/yql/providers/generic/connector/api/service/protos/connector.proto index bf8a0a6c2657..cfe607fd7909 100644 --- a/ydb/library/yql/providers/generic/connector/api/service/protos/connector.proto +++ b/ydb/library/yql/providers/generic/connector/api/service/protos/connector.proto @@ -352,6 +352,12 @@ message TExpression { TExpression else_expression = 3; } + // CAST($value AS $type) + message TCast { + TExpression value = 1; + string type = 2; + } + message TNull { } @@ -368,6 +374,8 @@ message TExpression { TCoalesce coalesce = 5; TIf if = 6; + + TCast cast = 7; } } diff --git a/ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.cpp b/ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.cpp index 028401af2a57..fbdb598591fb 100644 --- a/ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.cpp +++ b/ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.cpp @@ -2,6 +2,7 @@ #include #include +#include #include namespace NYql { @@ -26,16 +27,30 @@ namespace NYql { TString FormatIfExpression(const TExpression::TIf& sqlIf); namespace { - - bool SerializeMember(const TCoMember& member, TExpression* proto, const TCoArgument& arg, TStringBuilder& err) { - if (member.Struct().Raw() != arg.Raw()) { // member callable called not for lambda argument - err << "member callable called not for lambda argument"; + struct TSerializationContext { + const TCoArgument& Arg; + TStringBuilder& Err; + std::unordered_map LambdaArgs = {}; + }; + + bool SerializeMember(const TCoMember& member, TExpression* proto, TSerializationContext& ctx) { + if (member.Struct().Raw() != ctx.Arg.Raw()) { // member callable called not for lambda argument + ctx.Err << "member callable called not for lambda argument"; return false; } proto->set_column(member.Name().StringValue()); return true; } + bool SerializeLambdaArgument(const TExprBase& node, TExpression* proto, TSerializationContext& ctx) { + const auto it = ctx.LambdaArgs.find(node.Raw()); + if (it == ctx.LambdaArgs.end()) { // node is not lambda argument + return false; + } + *proto = it->second; + return true; + } + template T Cast(const TStringBuf& from) { return FromString(from); @@ -47,6 +62,71 @@ namespace NYql { return TString(from); } + bool SerializeExpression(const TExprBase& expression, TExpression* proto, TSerializationContext& ctx, ui64 depth); + + bool SerializeCastExpression(const TCoSafeCast& safeCast, TExpression* proto, TSerializationContext& ctx, ui64 depth) { + const auto typeAnnotation = safeCast.Type().Ref().GetTypeAnn(); + if (!typeAnnotation) { + ctx.Err << "expected non empty type annotation for safe cast"; + return false; + } + + auto* dstProto = proto->mutable_cast(); + dstProto->set_type(FormatType(typeAnnotation->Cast()->GetType())); + return SerializeExpression(TExprBase(safeCast.Value()), dstProto->mutable_value(), ctx, depth + 1); + } + + bool SerializeToBytesExpression(const TExprBase& toBytes, TExpression* proto, TSerializationContext& ctx, ui64 depth) { + if (toBytes.Ref().ChildrenSize() != 1) { + ctx.Err << "invalid ToBytes expression, expected 1 child but got " << toBytes.Ref().ChildrenSize(); + return false; + } + + const auto toBytexExpr = TExprBase(toBytes.Ref().Child(0)); + auto typeAnnotation = toBytexExpr.Ref().GetTypeAnn(); + if (!typeAnnotation) { + ctx.Err << "expected non empty type annotation for ToBytes"; + return false; + } + if (typeAnnotation->GetKind() == ETypeAnnotationKind::Optional) { + typeAnnotation = typeAnnotation->Cast()->GetItemType(); + } + if (typeAnnotation->GetKind() != ETypeAnnotationKind::Data) { + ctx.Err << "expected data type or optional from data type in ToBytes"; + return false; + } + + const auto dataSlot = typeAnnotation->Cast()->GetSlot(); + if (!IsDataTypeString(dataSlot) && dataSlot != NUdf::EDataSlot::JsonDocument) { + ctx.Err << "expected only string like input type for ToBytes"; + return false; + } + + auto* dstProto = proto->mutable_cast(); + dstProto->set_type("String"); + return SerializeExpression(toBytexExpr, dstProto->mutable_value(), ctx, depth + 1); + } + + bool SerializeFlatMap(const TCoFlatMap& flatMap, TExpression* proto, TSerializationContext& ctx, ui64 depth) { + const auto lambda = flatMap.Lambda(); + const auto lambdaArgs = lambda.Args(); + if (lambdaArgs.Size() != 1) { + ctx.Err << "expected only one argument for flat map lambda"; + return false; + } + + auto* dstProto = proto->mutable_if_(); + dstProto->mutable_else_expression()->mutable_null(); + auto* dstInput = dstProto->mutable_predicate()->mutable_is_not_null()->mutable_value(); + if (!SerializeExpression(flatMap.Input(), dstInput, ctx, depth + 1)) { + return false; + } + + // Duplicated arguments is ok, maybe one lambda was used twice + ctx.LambdaArgs.insert({lambdaArgs.Ref().Child(0), *dstInput}); + return SerializeExpression(lambda.Body(), dstProto->mutable_then_expression(), ctx, depth + 1); + } + #define MATCH_ATOM(AtomType, ATOM_ENUM, proto_name, cpp_type) \ if (auto atom = expression.Maybe()) { \ auto* value = proto->mutable_typed_value(); \ @@ -62,25 +142,34 @@ namespace NYql { auto expr = maybeExpr.Cast(); \ auto* exprProto = proto->mutable_arithmetical_expression(); \ exprProto->set_operation(TExpression::TArithmeticalExpression::OP_ENUM); \ - return SerializeExpression(expr.Left(), exprProto->mutable_left_value(), arg, err, depth + 1) && SerializeExpression(expr.Right(), exprProto->mutable_right_value(), arg, err, depth + 1); \ + return SerializeExpression(expr.Left(), exprProto->mutable_left_value(), ctx, depth + 1) && SerializeExpression(expr.Right(), exprProto->mutable_right_value(), ctx, depth + 1); \ } - bool SerializeSqlIfExpression(const TCoIf& sqlIf, TExpression* proto, const TCoArgument& arg, TStringBuilder& err, ui64 depth); + bool SerializeSqlIfExpression(const TCoIf& sqlIf, TExpression* proto, TSerializationContext& ctx, ui64 depth); - bool SerializeCoalesceExpression(const TCoCoalesce& coalesce, TExpression* proto, const TCoArgument& arg, TStringBuilder& err, ui64 depth); + bool SerializeCoalesceExpression(const TCoCoalesce& coalesce, TExpression* proto, TSerializationContext& ctx, ui64 depth); - bool SerializeExpression(const TExprBase& expression, TExpression* proto, const TCoArgument& arg, TStringBuilder& err, ui64 depth) { + bool SerializeExpression(const TExprBase& expression, TExpression* proto, TSerializationContext& ctx, ui64 depth) { if (auto member = expression.Maybe()) { - return SerializeMember(member.Cast(), proto, arg, err); + return SerializeMember(member.Cast(), proto, ctx); } if (auto coalesce = expression.Maybe()) { - return SerializeCoalesceExpression(coalesce.Cast(), proto, arg, err, depth); + return SerializeCoalesceExpression(coalesce.Cast(), proto, ctx, depth); } if (auto sqlIf = expression.Maybe()) { - return SerializeSqlIfExpression(sqlIf.Cast(), proto, arg, err, depth); + return SerializeSqlIfExpression(sqlIf.Cast(), proto, ctx, depth); } if (auto just = expression.Maybe()) { - return SerializeExpression(TExprBase(just.Cast().Input()), proto, arg, err, depth + 1); + return SerializeExpression(TExprBase(just.Cast().Input()), proto, ctx, depth + 1); + } + if (auto safeCast = expression.Maybe()) { + return SerializeCastExpression(safeCast.Cast(), proto, ctx, depth); + } + if (expression.Ref().IsCallable("ToBytes")) { + return SerializeToBytesExpression(expression, proto, ctx, depth); + } + if (auto flatMap = expression.Maybe()) { + return SerializeFlatMap(flatMap.Cast(), proto, ctx, depth); } // data @@ -109,7 +198,12 @@ namespace NYql { return true; } - err << "unknown expression: " << expression.Raw()->Content(); + // Try to serialize as lambda argument + if (SerializeLambdaArgument(expression, proto, ctx)) { + return true; + } + + ctx.Err << "unknown expression: " << expression.Raw()->Content(); return false; } @@ -122,7 +216,7 @@ namespace NYql { proto->set_operation(TPredicate::TComparison::COMPARE_TYPE); \ } - bool SerializeCompare(const TCoCompare& compare, TPredicate* predicateProto, const TCoArgument& arg, TStringBuilder& err, ui64 depth) { + bool SerializeCompare(const TCoCompare& compare, TPredicate* predicateProto, TSerializationContext& ctx, ui64 depth) { TPredicate::TComparison* proto = predicateProto->mutable_comparison(); bool opMatched = false; @@ -136,28 +230,28 @@ namespace NYql { EXPR_NODE_TO_COMPARE_TYPE(TCoAggrNotEqual, ID); if (proto->operation() == TPredicate::TComparison::COMPARISON_OPERATION_UNSPECIFIED) { - err << "unknown compare operation: " << compare.Raw()->Content(); + ctx.Err << "unknown compare operation: " << compare.Raw()->Content(); return false; } - return SerializeExpression(compare.Left(), proto->mutable_left_value(), arg, err, depth + 1) && SerializeExpression(compare.Right(), proto->mutable_right_value(), arg, err, depth + 1); + return SerializeExpression(compare.Left(), proto->mutable_left_value(), ctx, depth + 1) && SerializeExpression(compare.Right(), proto->mutable_right_value(), ctx, depth + 1); } #undef EXPR_NODE_TO_COMPARE_TYPE - bool SerializePredicate(const TExprBase& predicate, TPredicate* proto, const TCoArgument& arg, TStringBuilder& err, ui64 depth); + bool SerializePredicate(const TExprBase& predicate, TPredicate* proto, TSerializationContext& ctx, ui64 depth); - bool SerializeSqlIfExpression(const TCoIf& sqlIf, TExpression* proto, const TCoArgument& arg, TStringBuilder& err, ui64 depth) { + bool SerializeSqlIfExpression(const TCoIf& sqlIf, TExpression* proto, TSerializationContext& ctx, ui64 depth) { auto* dstProto = proto->mutable_if_(); - return SerializePredicate(TExprBase(sqlIf.Predicate()), dstProto->mutable_predicate(), arg, err, depth + 1) - && SerializeExpression(TExprBase(sqlIf.ThenValue()), dstProto->mutable_then_expression(), arg, err, depth + 1) - && SerializeExpression(TExprBase(sqlIf.ElseValue()), dstProto->mutable_else_expression(), arg, err, depth + 1); + return SerializePredicate(TExprBase(sqlIf.Predicate()), dstProto->mutable_predicate(), ctx, depth + 1) + && SerializeExpression(TExprBase(sqlIf.ThenValue()), dstProto->mutable_then_expression(), ctx, depth + 1) + && SerializeExpression(TExprBase(sqlIf.ElseValue()), dstProto->mutable_else_expression(), ctx, depth + 1); } - bool SerializeSqlIfPredicate(const TCoIf& sqlIf, TPredicate* proto, const TCoArgument& arg, TStringBuilder& err, ui64 depth) { + bool SerializeSqlIfPredicate(const TCoIf& sqlIf, TPredicate* proto, TSerializationContext& ctx, ui64 depth) { auto* dstProto = proto->mutable_if_(); - return SerializePredicate(TExprBase(sqlIf.Predicate()), dstProto->mutable_predicate(), arg, err, depth + 1) - && SerializePredicate(TExprBase(sqlIf.ThenValue()), dstProto->mutable_then_predicate(), arg, err, depth + 1) - && SerializePredicate(TExprBase(sqlIf.ElseValue()), dstProto->mutable_else_predicate(), arg, err, depth + 1); + return SerializePredicate(TExprBase(sqlIf.Predicate()), dstProto->mutable_predicate(), ctx, depth + 1) + && SerializePredicate(TExprBase(sqlIf.ThenValue()), dstProto->mutable_then_predicate(), ctx, depth + 1) + && SerializePredicate(TExprBase(sqlIf.ElseValue()), dstProto->mutable_else_predicate(), ctx, depth + 1); } template @@ -171,10 +265,10 @@ namespace NYql { } } - bool SerializeCoalesceExpression(const TCoCoalesce& coalesce, TExpression* proto, const TCoArgument& arg, TStringBuilder& err, ui64 depth) { + bool SerializeCoalesceExpression(const TCoCoalesce& coalesce, TExpression* proto, TSerializationContext& ctx, ui64 depth) { auto* dstProto = proto->mutable_coalesce(); for (const auto& child : coalesce.Ptr()->Children()) { - if (!SerializeExpression(TExprBase(child), dstProto->add_operands(), arg, err, depth + 1)) { + if (!SerializeExpression(TExprBase(child), dstProto->add_operands(), ctx, depth + 1)) { return false; } UnwrapNestedCoalesce(dstProto); @@ -182,19 +276,19 @@ namespace NYql { return true; } - bool SerializeCoalescePredicate(const TCoCoalesce& coalesce, TPredicate* proto, const TCoArgument& arg, TStringBuilder& err, ui64 depth) { + bool SerializeCoalescePredicate(const TCoCoalesce& coalesce, TPredicate* proto, TSerializationContext& ctx, ui64 depth) { // Special case for top level COALESCE: COALESCE(Predicat, FALSE) // We can assume NULL as FALSE and skip COALESCE if (depth == 0) { auto value = coalesce.Value().Maybe(); if (value && TStringBuf(value.Cast().Literal()) == "false"sv) { - return SerializePredicate(TExprBase(coalesce.Predicate()), proto, arg, err, 0); + return SerializePredicate(TExprBase(coalesce.Predicate()), proto, ctx, 0); } } auto* dstProto = proto->mutable_coalesce(); for (const auto& child : coalesce.Ptr()->Children()) { - if (!SerializePredicate(TExprBase(child), dstProto->add_operands(), arg, err, depth + 1)) { + if (!SerializePredicate(TExprBase(child), dstProto->add_operands(), ctx, depth + 1)) { return false; } UnwrapNestedCoalesce(dstProto); @@ -202,18 +296,18 @@ namespace NYql { return true; } - bool SerializeExists(const TCoExists& exists, TPredicate* proto, const TCoArgument& arg, TStringBuilder& err, bool withNot, ui64 depth) { + bool SerializeExists(const TCoExists& exists, TPredicate* proto, TSerializationContext& ctx, bool withNot, ui64 depth) { auto* expressionProto = withNot ? proto->mutable_is_null()->mutable_value() : proto->mutable_is_not_null()->mutable_value(); - return SerializeExpression(exists.Optional(), expressionProto, arg, err, depth + 1); + return SerializeExpression(exists.Optional(), expressionProto, ctx, depth + 1); } - bool SerializeSqlIn(const TCoSqlIn& sqlIn, TPredicate* proto, const TCoArgument& arg, TStringBuilder& err, ui64 depth) { + bool SerializeSqlIn(const TCoSqlIn& sqlIn, TPredicate* proto, TSerializationContext& ctx, ui64 depth) { auto* dstProto = proto->mutable_in(); const TExprBase& expr = sqlIn.Collection(); const TExprBase& lookup = sqlIn.Lookup(); auto* expressionProto = dstProto->mutable_value(); - SerializeExpression(lookup, expressionProto, arg, err, depth + 1); + SerializeExpression(lookup, expressionProto, ctx, depth + 1); TExprNode::TPtr collection; if (expr.Ref().IsList()) { @@ -221,145 +315,145 @@ namespace NYql { } else if (auto maybeAsList = expr.Maybe()) { collection = maybeAsList.Cast().Ptr(); } else { - err << "unknown source for in: " << expr.Ref().Content(); + ctx.Err << "unknown source for in: " << expr.Ref().Content(); return false; } for (auto& child : collection->Children()) { - if (!SerializeExpression(TExprBase(child), dstProto->add_set(), arg, err, depth + 1)) { + if (!SerializeExpression(TExprBase(child), dstProto->add_set(), ctx, depth + 1)) { return false; } } return true; } - bool SerializeIsNotDistinctFrom(const TExprBase& predicate, TPredicate* predicateProto, const TCoArgument& arg, TStringBuilder& err, bool invert, ui64 depth) { + bool SerializeIsNotDistinctFrom(const TExprBase& predicate, TPredicate* predicateProto, TSerializationContext& ctx, bool invert, ui64 depth) { if (predicate.Ref().ChildrenSize() != 2) { - err << "invalid IsNotDistinctFrom predicate, expected 2 children but got " << predicate.Ref().ChildrenSize(); + ctx.Err << "invalid IsNotDistinctFrom predicate, expected 2 children but got " << predicate.Ref().ChildrenSize(); return false; } TPredicate::TComparison* proto = predicateProto->mutable_comparison(); proto->set_operation(!invert ? TPredicate::TComparison::IND : TPredicate::TComparison::ID); - return SerializeExpression(TExprBase(predicate.Ref().Child(0)), proto->mutable_left_value(), arg, err, depth + 1) - && SerializeExpression(TExprBase(predicate.Ref().Child(1)), proto->mutable_right_value(), arg, err, depth + 1); + return SerializeExpression(TExprBase(predicate.Ref().Child(0)), proto->mutable_left_value(), ctx, depth + 1) + && SerializeExpression(TExprBase(predicate.Ref().Child(1)), proto->mutable_right_value(), ctx, depth + 1); } - bool SerializeAnd(const TCoAnd& andExpr, TPredicate* proto, const TCoArgument& arg, TStringBuilder& err, ui64 depth) { + bool SerializeAnd(const TCoAnd& andExpr, TPredicate* proto, TSerializationContext& ctx, ui64 depth) { auto* dstProto = proto->mutable_conjunction(); for (const auto& child : andExpr.Ptr()->Children()) { - if (!SerializePredicate(TExprBase(child), dstProto->add_operands(), arg, err, depth + 1)) { + if (!SerializePredicate(TExprBase(child), dstProto->add_operands(), ctx, depth + 1)) { return false; } } return true; } - bool SerializeOr(const TCoOr& orExpr, TPredicate* proto, const TCoArgument& arg, TStringBuilder& err, ui64 depth) { + bool SerializeOr(const TCoOr& orExpr, TPredicate* proto, TSerializationContext& ctx, ui64 depth) { auto* dstProto = proto->mutable_disjunction(); for (const auto& child : orExpr.Ptr()->Children()) { - if (!SerializePredicate(TExprBase(child), dstProto->add_operands(), arg, err, depth + 1)) { + if (!SerializePredicate(TExprBase(child), dstProto->add_operands(), ctx, depth + 1)) { return false; } } return true; } - bool SerializeNot(const TCoNot& notExpr, TPredicate* proto, const TCoArgument& arg, TStringBuilder& err, ui64 depth) { + bool SerializeNot(const TCoNot& notExpr, TPredicate* proto, TSerializationContext& ctx, ui64 depth) { // Special case: (Not (Exists ...)) if (auto exists = notExpr.Value().Maybe()) { - return SerializeExists(exists.Cast(), proto, arg, err, true, depth + 1); + return SerializeExists(exists.Cast(), proto, ctx, true, depth + 1); } auto* dstProto = proto->mutable_negation(); - return SerializePredicate(notExpr.Value(), dstProto->mutable_operand(), arg, err, depth + 1); + return SerializePredicate(notExpr.Value(), dstProto->mutable_operand(), ctx, depth + 1); } - bool SerializeMember(const TCoMember& member, TPredicate* proto, const TCoArgument& arg, TStringBuilder& err) { - return SerializeMember(member, proto->mutable_bool_expression()->mutable_value(), arg, err); + bool SerializeMember(const TCoMember& member, TPredicate* proto, TSerializationContext& ctx) { + return SerializeMember(member, proto->mutable_bool_expression()->mutable_value(), ctx); } - bool SerializeRegexp(const TCoUdf& regexp, const TExprNode::TListType& children, TPredicate* proto, const TCoArgument& arg, TStringBuilder& err, ui64 depth) { + bool SerializeRegexp(const TCoUdf& regexp, const TExprNode::TListType& children, TPredicate* proto, TSerializationContext& ctx, ui64 depth) { if (children.size() != 2) { - err << "expected exactly one argument for UDF Re2.Grep, but got: " << children.size() - 1; + ctx.Err << "expected exactly one argument for UDF function Re2.Grep, but got: " << children.size() - 1; return false; } const auto& maybeRunConfig = regexp.RunConfigValue(); if (!maybeRunConfig) { - err << "predicate for REGEXP can't be empty"; + ctx.Err << "predicate for REGEXP can't be empty"; return false; } const auto& runConfig = maybeRunConfig.Cast().Ref(); if (runConfig.ChildrenSize() != 2) { - err << "expected exactly two run config options for UDF Re2.Grep, but got: " << runConfig.ChildrenSize(); + ctx.Err << "expected exactly two run config options for UDF Re2.Grep, but got: " << runConfig.ChildrenSize(); return false; } auto* dstProto = proto->mutable_regexp(); - return SerializeExpression(TExprBase(runConfig.ChildPtr(0)), dstProto->mutable_pattern(), arg, err, depth + 1) - && SerializeExpression(TExprBase(children[1]), dstProto->mutable_value(), arg, err, depth + 1); + return SerializeExpression(TExprBase(runConfig.ChildPtr(0)), dstProto->mutable_pattern(), ctx, depth + 1) + && SerializeExpression(TExprBase(children[1]), dstProto->mutable_value(), ctx, depth + 1); } - bool SerializeApply(const TCoApply& apply, TPredicate* proto, const TCoArgument& arg, TStringBuilder& err, ui64 depth) { + bool SerializeApply(const TCoApply& apply, TPredicate* proto, TSerializationContext& ctx, ui64 depth) { const auto& maybeUdf = apply.Callable().Maybe(); if (!maybeUdf) { - err << "expected only UDF apply, but got: " << apply.Callable().Ref().Content(); + ctx.Err << "expected only UDF apply, but got: " << apply.Callable().Ref().Content(); return false; } const auto& udf = maybeUdf.Cast(); if (TStringBuf(udf.MethodName()) == "Re2.Grep"sv) { - return SerializeRegexp(udf, apply.Ref().ChildrenList(), proto, arg, err, depth); + return SerializeRegexp(udf, apply.Ref().ChildrenList(), proto, ctx, depth); } - err << "unknown UDF in apply: " << TStringBuf(udf.MethodName()); + ctx.Err << "unknown UDF in apply: " << TStringBuf(udf.MethodName()); return false; } - bool SerializePredicate(const TExprBase& predicate, TPredicate* proto, const TCoArgument& arg, TStringBuilder& err, ui64 depth) { + bool SerializePredicate(const TExprBase& predicate, TPredicate* proto, TSerializationContext& ctx, ui64 depth) { if (auto compare = predicate.Maybe()) { - return SerializeCompare(compare.Cast(), proto, arg, err, depth); + return SerializeCompare(compare.Cast(), proto, ctx, depth); } if (auto coalesce = predicate.Maybe()) { - return SerializeCoalescePredicate(coalesce.Cast(), proto, arg, err, depth); + return SerializeCoalescePredicate(coalesce.Cast(), proto, ctx, depth); } if (auto andExpr = predicate.Maybe()) { - return SerializeAnd(andExpr.Cast(), proto, arg, err, depth); + return SerializeAnd(andExpr.Cast(), proto, ctx, depth); } if (auto orExpr = predicate.Maybe()) { - return SerializeOr(orExpr.Cast(), proto, arg, err, depth); + return SerializeOr(orExpr.Cast(), proto, ctx, depth); } if (auto notExpr = predicate.Maybe()) { - return SerializeNot(notExpr.Cast(), proto, arg, err, depth); + return SerializeNot(notExpr.Cast(), proto, ctx, depth); } if (auto member = predicate.Maybe()) { - return SerializeMember(member.Cast(), proto, arg, err); + return SerializeMember(member.Cast(), proto, ctx); } if (auto exists = predicate.Maybe()) { - return SerializeExists(exists.Cast(), proto, arg, err, false, depth); + return SerializeExists(exists.Cast(), proto, ctx, false, depth); } if (auto sqlIn = predicate.Maybe()) { - return SerializeSqlIn(sqlIn.Cast(), proto, arg, err, depth); + return SerializeSqlIn(sqlIn.Cast(), proto, ctx, depth); } if (predicate.Ref().IsCallable("IsNotDistinctFrom")) { - return SerializeIsNotDistinctFrom(predicate, proto, arg, err, false, depth); + return SerializeIsNotDistinctFrom(predicate, proto, ctx, false, depth); } if (predicate.Ref().IsCallable("IsDistinctFrom")) { - return SerializeIsNotDistinctFrom(predicate, proto, arg, err, true, depth); + return SerializeIsNotDistinctFrom(predicate, proto, ctx, true, depth); } if (auto sqlIf = predicate.Maybe()) { - return SerializeSqlIfPredicate(sqlIf.Cast(), proto, arg, err, depth); + return SerializeSqlIfPredicate(sqlIf.Cast(), proto, ctx, depth); } if (auto just = predicate.Maybe()) { - return SerializePredicate(TExprBase(just.Cast().Input()), proto, arg, err, depth + 1); + return SerializePredicate(TExprBase(just.Cast().Input()), proto, ctx, depth + 1); } if (auto apply = predicate.Maybe()) { - return SerializeApply(apply.Cast(), proto, arg, err, depth); + return SerializeApply(apply.Cast(), proto, ctx, depth); } // Try to serialize predicate as boolean expression // For example single bool value TRUE in COALESCE or IF - return SerializeExpression(predicate, proto->mutable_bool_expression()->mutable_value(), arg, err, depth); + return SerializeExpression(predicate, proto->mutable_bool_expression()->mutable_value(), ctx, depth); } } @@ -396,6 +490,11 @@ namespace NYql { return "NULL"; } + TString FormatCast(const TExpression::TCast& cast) { + auto value = FormatExpression(cast.value()); + return TStringBuilder() << "CAST(" << value << " AS " << cast.type() << ")"; + } + TString FormatExpression(const TExpression& expression) { switch (expression.payload_case()) { case TExpression::kColumn: @@ -410,8 +509,10 @@ namespace NYql { return FormatCoalesce(expression.coalesce()); case TExpression::kIf: return FormatIfExpression(expression.if_()); + case TExpression::kCast: + return FormatCast(expression.cast()); default: - throw yexception() << "UnimplementedExpression, payload_case " << static_cast(expression.payload_case()); + throw yexception() << "Failed to format expression, unimplemented payload_case " << static_cast(expression.payload_case()); } } @@ -697,8 +798,13 @@ namespace NYql { return TStringBuf(maybeBool.Cast().Literal()) == "true"sv; } + bool SerializeFilterPredicate(const TExprBase& predicateBody, const TCoArgument& predicateArgument, NConnector::NApi::TPredicate* proto, TStringBuilder& err) { + TSerializationContext ctx = {.Arg = predicateArgument, .Err = err}; + return SerializePredicate(predicateBody, proto, ctx, 0); + } + bool SerializeFilterPredicate(const TCoLambda& predicate, TPredicate* proto, TStringBuilder& err) { - return SerializePredicate(predicate.Body(), proto, predicate.Args().Arg(0), err, 0); + return SerializeFilterPredicate(predicate.Body(), predicate.Args().Arg(0), proto, err); } TString FormatWhere(const TPredicate& predicate) { diff --git a/ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.h b/ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.h index b798e483b8a5..b9c2f06fe4da 100644 --- a/ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.h +++ b/ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.h @@ -9,6 +9,7 @@ namespace NYql::NConnector::NApi { namespace NYql { bool IsEmptyFilterPredicate(const NNodes::TCoLambda& lambda); + bool SerializeFilterPredicate(const NNodes::TExprBase& predicateBody, const NNodes::TCoArgument& predicateArgument, NConnector::NApi::TPredicate* proto, TStringBuilder& err); bool SerializeFilterPredicate(const NNodes::TCoLambda& predicate, NConnector::NApi::TPredicate* proto, TStringBuilder& err); TString FormatWhere(const NConnector::NApi::TPredicate& predicate); } // namespace NYql diff --git a/ydb/library/yql/providers/generic/pushdown/yql_generic_match_predicate.cpp b/ydb/library/yql/providers/generic/pushdown/yql_generic_match_predicate.cpp index 5b7058c4e1d8..6f3d7fb03ee3 100644 --- a/ydb/library/yql/providers/generic/pushdown/yql_generic_match_predicate.cpp +++ b/ydb/library/yql/providers/generic/pushdown/yql_generic_match_predicate.cpp @@ -55,6 +55,7 @@ namespace NYql::NGenericPushDown { case NYql::NConnector::NApi::TExpression::kNull: case NYql::NConnector::NApi::TExpression::kCoalesce: case NYql::NConnector::NApi::TExpression::kIf: + case NYql::NConnector::NApi::TExpression::kCast: case NYql::NConnector::NApi::TExpression::PAYLOAD_NOT_SET: return false; } @@ -70,6 +71,7 @@ namespace NYql::NGenericPushDown { case NYql::NConnector::NApi::TExpression::kNull: case NYql::NConnector::NApi::TExpression::kCoalesce: case NYql::NConnector::NApi::TExpression::kIf: + case NYql::NConnector::NApi::TExpression::kCast: case NYql::NConnector::NApi::TExpression::PAYLOAD_NOT_SET: return false; } @@ -281,6 +283,7 @@ namespace NYql::NGenericPushDown { case NYql::NConnector::NApi::TExpression::kNull: case NYql::NConnector::NApi::TExpression::kCoalesce: case NYql::NConnector::NApi::TExpression::kIf: + case NYql::NConnector::NApi::TExpression::kCast: case NYql::NConnector::NApi::TExpression::PAYLOAD_NOT_SET: return Triple::Unknown; } diff --git a/ydb/library/yql/providers/pq/provider/yql_pq_logical_opt.cpp b/ydb/library/yql/providers/pq/provider/yql_pq_logical_opt.cpp index 774feeb48c91..60a541aa94c9 100644 --- a/ydb/library/yql/providers/pq/provider/yql_pq_logical_opt.cpp +++ b/ydb/library/yql/providers/pq/provider/yql_pq_logical_opt.cpp @@ -34,7 +34,8 @@ namespace { // Operator features EFlag::ExpressionAsPredicate | EFlag::ArithmeticalExpressions | EFlag::ImplicitConversionToInt64 | EFlag::StringTypes | EFlag::LikeOperator | EFlag::DoNotCheckCompareArgumentsTypes | EFlag::InOperator | - EFlag::IsDistinctOperator | EFlag::JustPassthroughOperators | DivisionExpressions | + EFlag::IsDistinctOperator | EFlag::JustPassthroughOperators | DivisionExpressions | EFlag::CastExpression | + EFlag::ToBytesFromStringExpressions | EFlag::FlatMapOverOptionals | // Split features EFlag::SplitOrOperator @@ -267,19 +268,14 @@ class TPqLogicalOptProposalTransformer : public TOptimizeTransformerBase { return node; } - auto newFilterLambda = MakePushdownPredicate(flatmap.Lambda(), ctx, node.Pos(), TPushdownSettings()); - if (!newFilterLambda) { - return node; - } - - auto predicate = newFilterLambda.Cast(); - if (NYql::IsEmptyFilterPredicate(predicate)) { + NPushdown::TPredicateNode predicate = MakePushdownNode(flatmap.Lambda(), ctx, node.Pos(), TPushdownSettings()); + if (predicate.IsEmpty()) { return node; } TStringBuilder err; NYql::NConnector::NApi::TPredicate predicateProto; - if (!NYql::SerializeFilterPredicate(predicate, &predicateProto, err)) { + if (!NYql::SerializeFilterPredicate(predicate.ExprNode.Cast(), flatmap.Lambda().Args().Arg(0), &predicateProto, err)) { ctx.AddWarning(TIssue(ctx.GetPosition(node.Pos()), "Failed to serialize filter predicate for source: " + err)); return node; } diff --git a/ydb/tests/fq/yds/test_row_dispatcher.py b/ydb/tests/fq/yds/test_row_dispatcher.py index 94e31a4f7983..136908c1010c 100644 --- a/ydb/tests/fq/yds/test_row_dispatcher.py +++ b/ydb/tests/fq/yds/test_row_dispatcher.py @@ -308,10 +308,10 @@ def test_filters_non_optional_field(self, kikimr, client): sql = Rf''' INSERT INTO {YDS_CONNECTION}.`{self.output_topic}` SELECT Cast(time as String) FROM {YDS_CONNECTION}.`{self.input_topic}` - WITH (format=json_each_row, SCHEMA (time UInt64 NOT NULL, data String NOT NULL, event String NOT NULL)) WHERE ''' + WITH (format=json_each_row, SCHEMA (time UInt64 NOT NULL, data String NOT NULL, event String NOT NULL, nested Json NOT NULL)) WHERE ''' data = [ - '{"time": 101, "data": "hello1", "event": "event1"}', - '{"time": 102, "data": "hello2", "event": "event2"}'] + '{"time": 101, "data": "hello1", "event": "event1", "nested": {"xyz": "key"}}', + '{"time": 102, "data": "hello2", "event": "event2", "nested": ["abc", "key"]}'] filter = "time > 101;" expected = ['102'] self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (`time` > 101)') @@ -331,6 +331,10 @@ def test_filters_non_optional_field(self, kikimr, client): self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE ((`event` IS DISTINCT FROM `data`) AND (`event` IN (\\"1\\"') filter = ' IF(event = "event2", event IS DISTINCT FROM data, FALSE)' self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE IF((`event` = \\"event2\\"), (`event` IS DISTINCT FROM `data`), FALSE)') + filter = ' nested REGEXP ".*abc.*"' + self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (CAST(`nested` AS String) REGEXP ".*abc.*")') + filter = ' CAST(nested AS String) REGEXP ".*abc.*"' + self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (CAST(`nested` AS String) REGEXP ".*abc.*")') @yq_v1 def test_filters_optional_field(self, kikimr, client): @@ -342,10 +346,10 @@ def test_filters_optional_field(self, kikimr, client): sql = Rf''' INSERT INTO {YDS_CONNECTION}.`{self.output_topic}` SELECT Cast(time as String) FROM {YDS_CONNECTION}.`{self.input_topic}` - WITH (format=json_each_row, SCHEMA (time UInt64 NOT NULL, data String, event String, flag Bool, field1 UInt8, field2 Int64)) WHERE ''' + WITH (format=json_each_row, SCHEMA (time UInt64 NOT NULL, data String, event String, flag Bool, field1 UInt8, field2 Int64, nested Json)) WHERE ''' data = [ - '{"time": 101, "data": "hello1", "event": "event1", "flag": false, "field1": 5, "field2": 5}', - '{"time": 102, "data": "hello2", "event": "event2", "flag": true, "field1": 5, "field2": 1005}'] + '{"time": 101, "data": "hello1", "event": "event1", "flag": false, "field1": 5, "field2": 5, "nested": {"xyz": "key"}}', + '{"time": 102, "data": "hello2", "event": "event2", "flag": true, "field1": 5, "field2": 1005, "nested": ["abc", "key"]}'] expected = ['102'] filter = 'data = "hello2"' self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (`data` = \\"hello2\\")') @@ -381,6 +385,10 @@ def test_filters_optional_field(self, kikimr, client): self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (NOT (COALESCE(`event`, \\"\\") REGEXP \\"e.*e.*t1\\"))') filter = " event ?? '' REGEXP data ?? '' OR time = 102" self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE ((COALESCE(`event`, \\"\\") REGEXP COALESCE(`data`, \\"\\")) OR (`time` = 102))') + filter = ' nested REGEXP ".*abc.*"' + self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (IF((`nested` IS NOT NULL), CAST(`nested` AS String), NULL) REGEXP ".*abc.*")') + filter = ' CAST(nested AS String) REGEXP ".*abc.*"' + self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (CAST(`nested` AS Optional) REGEXP ".*abc.*")') @yq_v1 def test_filter_missing_fields(self, kikimr, client): From a7824a5806ec44140ea99129eea9dd6812a941aa Mon Sep 17 00:00:00 2001 From: Grigoriy Pisarenko Date: Wed, 27 Nov 2024 17:27:58 +0000 Subject: [PATCH 2/4] Fixed integration tests --- ydb/tests/fq/yds/test_row_dispatcher.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ydb/tests/fq/yds/test_row_dispatcher.py b/ydb/tests/fq/yds/test_row_dispatcher.py index 136908c1010c..e1f33385b088 100644 --- a/ydb/tests/fq/yds/test_row_dispatcher.py +++ b/ydb/tests/fq/yds/test_row_dispatcher.py @@ -332,9 +332,9 @@ def test_filters_non_optional_field(self, kikimr, client): filter = ' IF(event = "event2", event IS DISTINCT FROM data, FALSE)' self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE IF((`event` = \\"event2\\"), (`event` IS DISTINCT FROM `data`), FALSE)') filter = ' nested REGEXP ".*abc.*"' - self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (CAST(`nested` AS String) REGEXP ".*abc.*")') + self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (CAST(`nested` AS String) REGEXP \\".*abc.*\\")') filter = ' CAST(nested AS String) REGEXP ".*abc.*"' - self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (CAST(`nested` AS String) REGEXP ".*abc.*")') + self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (CAST(`nested` AS String) REGEXP \\".*abc.*\\")') @yq_v1 def test_filters_optional_field(self, kikimr, client): @@ -386,9 +386,9 @@ def test_filters_optional_field(self, kikimr, client): filter = " event ?? '' REGEXP data ?? '' OR time = 102" self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE ((COALESCE(`event`, \\"\\") REGEXP COALESCE(`data`, \\"\\")) OR (`time` = 102))') filter = ' nested REGEXP ".*abc.*"' - self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (IF((`nested` IS NOT NULL), CAST(`nested` AS String), NULL) REGEXP ".*abc.*")') + self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (IF((`nested` IS NOT NULL), CAST(`nested` AS String), NULL) REGEXP \\".*abc.*\\")') filter = ' CAST(nested AS String) REGEXP ".*abc.*"' - self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (CAST(`nested` AS Optional) REGEXP ".*abc.*")') + self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (CAST(`nested` AS Optional) REGEXP \\".*abc.*\\")') @yq_v1 def test_filter_missing_fields(self, kikimr, client): From 623f690220df0bf57e4218cfc5288927fdff6633 Mon Sep 17 00:00:00 2001 From: Grigoriy Pisarenko Date: Thu, 28 Nov 2024 10:20:53 +0000 Subject: [PATCH 3/4] Fixed type passing in proto --- .../api/service/protos/connector.proto | 2 +- .../yql_generic_predicate_pushdown.cpp | 102 ++++++++++++++++-- 2 files changed, 97 insertions(+), 7 deletions(-) diff --git a/ydb/library/yql/providers/generic/connector/api/service/protos/connector.proto b/ydb/library/yql/providers/generic/connector/api/service/protos/connector.proto index cfe607fd7909..1d638fab1238 100644 --- a/ydb/library/yql/providers/generic/connector/api/service/protos/connector.proto +++ b/ydb/library/yql/providers/generic/connector/api/service/protos/connector.proto @@ -355,7 +355,7 @@ message TExpression { // CAST($value AS $type) message TCast { TExpression value = 1; - string type = 2; + Ydb.Type type = 2; } message TNull { diff --git a/ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.cpp b/ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.cpp index fbdb598591fb..f862bfe45a5b 100644 --- a/ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.cpp +++ b/ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.cpp @@ -64,18 +64,61 @@ namespace NYql { bool SerializeExpression(const TExprBase& expression, TExpression* proto, TSerializationContext& ctx, ui64 depth); +#define MATCH_TYPE(DataType, PROTO_TYPE) \ + if (dataSlot == NUdf::EDataSlot::DataType) { \ + dstType->set_type_id(Ydb::Type::PROTO_TYPE); \ + return true; \ + } + bool SerializeCastExpression(const TCoSafeCast& safeCast, TExpression* proto, TSerializationContext& ctx, ui64 depth) { const auto typeAnnotation = safeCast.Type().Ref().GetTypeAnn(); if (!typeAnnotation) { ctx.Err << "expected non empty type annotation for safe cast"; return false; } + if (typeAnnotation->GetKind() != ETypeAnnotationKind::Type) { + ctx.Err << "expected only type expression for safe cast"; + return false; + } auto* dstProto = proto->mutable_cast(); - dstProto->set_type(FormatType(typeAnnotation->Cast()->GetType())); - return SerializeExpression(TExprBase(safeCast.Value()), dstProto->mutable_value(), ctx, depth + 1); + if (!SerializeExpression(safeCast.Value(), dstProto->mutable_value(), ctx, depth + 1)) { + return false; + } + + auto type = typeAnnotation->Cast()->GetType(); + auto* dstType = dstProto->mutable_type(); + if (type->GetKind() == ETypeAnnotationKind::Optional) { + type = type->Cast()->GetItemType(); + dstType = dstType->mutable_optional_type()->mutable_item(); + } + if (type->GetKind() != ETypeAnnotationKind::Data) { + ctx.Err << "expected only data type or optional data type for safe cast"; + return false; + } + const auto dataSlot = type->Cast()->GetSlot(); + + MATCH_TYPE(Bool, BOOL); + MATCH_TYPE(Int8, INT8); + MATCH_TYPE(Int16, INT16); + MATCH_TYPE(Int32, INT32); + MATCH_TYPE(Int64, INT64); + MATCH_TYPE(Uint8, UINT8); + MATCH_TYPE(Uint16, UINT16); + MATCH_TYPE(Uint32, UINT32); + MATCH_TYPE(Uint64, UINT64); + MATCH_TYPE(Float, FLOAT); + MATCH_TYPE(Double, DOUBLE); + MATCH_TYPE(String, STRING); + MATCH_TYPE(Utf8, UTF8); + MATCH_TYPE(Json, JSON); + + ctx.Err << "unknown data slot " << static_cast(dataSlot) << " for safe cast"; + return false; } +#undef MATCH_TYPE + bool SerializeToBytesExpression(const TExprBase& toBytes, TExpression* proto, TSerializationContext& ctx, ui64 depth) { if (toBytes.Ref().ChildrenSize() != 1) { ctx.Err << "invalid ToBytes expression, expected 1 child but got " << toBytes.Ref().ChildrenSize(); @@ -103,7 +146,7 @@ namespace NYql { } auto* dstProto = proto->mutable_cast(); - dstProto->set_type("String"); + dstProto->mutable_type()->set_type_id(Ydb::Type::STRING); return SerializeExpression(toBytexExpr, dstProto->mutable_value(), ctx, depth + 1); } @@ -126,7 +169,7 @@ namespace NYql { ctx.LambdaArgs.insert({lambdaArgs.Ref().Child(0), *dstInput}); return SerializeExpression(lambda.Body(), dstProto->mutable_then_expression(), ctx, depth + 1); } - + #define MATCH_ATOM(AtomType, ATOM_ENUM, proto_name, cpp_type) \ if (auto atom = expression.Maybe()) { \ auto* value = proto->mutable_typed_value(); \ @@ -482,7 +525,53 @@ namespace NYql { case Ydb::Value::kTextValue: return NFq::EncloseAndEscapeString(value.value().text_value(), '"'); default: - throw yexception() << "ErrUnimplementedTypedValue, value case " << static_cast(value.value().value_case()); + throw yexception() << "Failed to format ydb typed vlaue, value case " << static_cast(value.value().value_case()) << " is not supported"; + } + } + + TString FormatPrimitiveType(const Ydb::Type::PrimitiveTypeId& typeId) { + switch (typeId) { + case Ydb::Type::BOOL: + return "Bool"; + case Ydb::Type::INT8: + return "Int8"; + case Ydb::Type::INT16: + return "Int16"; + case Ydb::Type::INT32: + return "Int32"; + case Ydb::Type::INT64: + return "Int64"; + case Ydb::Type::UINT8: + return "Uint8"; + case Ydb::Type::UINT16: + return "Uint16"; + case Ydb::Type::UINT32: + return "Uint32"; + case Ydb::Type::UINT64: + return "Uint64"; + case Ydb::Type::FLOAT: + return "Float"; + case Ydb::Type::DOUBLE: + return "Double"; + case Ydb::Type::STRING: + return "String"; + case Ydb::Type::UTF8: + return "Utf8"; + case Ydb::Type::JSON: + return "Json"; + default: + throw yexception() << "Failed to format primitive type, type case " << static_cast(typeId) << " is not supported"; + } + } + + TString FormatType(const Ydb::Type& type) { + switch (type.type_case()) { + case Ydb::Type::kTypeId: + return FormatPrimitiveType(type.type_id()); + case Ydb::Type::kOptionalType: + return TStringBuilder() << FormatType(type.optional_type().item()) << "?"; + default: + throw yexception() << "Failed to format ydb type, type id " << static_cast(type.type_case()) << " is not supported"; } } @@ -492,7 +581,8 @@ namespace NYql { TString FormatCast(const TExpression::TCast& cast) { auto value = FormatExpression(cast.value()); - return TStringBuilder() << "CAST(" << value << " AS " << cast.type() << ")"; + auto type = FormatType(cast.type()); + return TStringBuilder() << "CAST(" << value << " AS " << type << ")"; } TString FormatExpression(const TExpression& expression) { From ac7bd7bb8f860018fd26f1edfd09413d0718e1b4 Mon Sep 17 00:00:00 2001 From: Grigoriy Pisarenko Date: Fri, 29 Nov 2024 07:13:15 +0000 Subject: [PATCH 4/4] Fixed integration test --- ydb/tests/fq/yds/test_row_dispatcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ydb/tests/fq/yds/test_row_dispatcher.py b/ydb/tests/fq/yds/test_row_dispatcher.py index e1f33385b088..9dbe9603582d 100644 --- a/ydb/tests/fq/yds/test_row_dispatcher.py +++ b/ydb/tests/fq/yds/test_row_dispatcher.py @@ -388,7 +388,7 @@ def test_filters_optional_field(self, kikimr, client): filter = ' nested REGEXP ".*abc.*"' self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (IF((`nested` IS NOT NULL), CAST(`nested` AS String), NULL) REGEXP \\".*abc.*\\")') filter = ' CAST(nested AS String) REGEXP ".*abc.*"' - self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (CAST(`nested` AS Optional) REGEXP \\".*abc.*\\")') + self.run_and_check(kikimr, client, sql + filter, data, expected, 'predicate: WHERE (CAST(`nested` AS String?) REGEXP \\".*abc.*\\")') @yq_v1 def test_filter_missing_fields(self, kikimr, client):