From 623f690220df0bf57e4218cfc5288927fdff6633 Mon Sep 17 00:00:00 2001 From: Grigoriy Pisarenko Date: Thu, 28 Nov 2024 10:20:53 +0000 Subject: [PATCH] Fixed type passing in proto --- .../api/service/protos/connector.proto | 2 +- .../yql_generic_predicate_pushdown.cpp | 102 ++++++++++++++++-- 2 files changed, 97 insertions(+), 7 deletions(-) diff --git a/ydb/library/yql/providers/generic/connector/api/service/protos/connector.proto b/ydb/library/yql/providers/generic/connector/api/service/protos/connector.proto index cfe607fd7909..1d638fab1238 100644 --- a/ydb/library/yql/providers/generic/connector/api/service/protos/connector.proto +++ b/ydb/library/yql/providers/generic/connector/api/service/protos/connector.proto @@ -355,7 +355,7 @@ message TExpression { // CAST($value AS $type) message TCast { TExpression value = 1; - string type = 2; + Ydb.Type type = 2; } message TNull { diff --git a/ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.cpp b/ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.cpp index fbdb598591fb..f862bfe45a5b 100644 --- a/ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.cpp +++ b/ydb/library/yql/providers/generic/provider/yql_generic_predicate_pushdown.cpp @@ -64,18 +64,61 @@ namespace NYql { bool SerializeExpression(const TExprBase& expression, TExpression* proto, TSerializationContext& ctx, ui64 depth); +#define MATCH_TYPE(DataType, PROTO_TYPE) \ + if (dataSlot == NUdf::EDataSlot::DataType) { \ + dstType->set_type_id(Ydb::Type::PROTO_TYPE); \ + return true; \ + } + bool SerializeCastExpression(const TCoSafeCast& safeCast, TExpression* proto, TSerializationContext& ctx, ui64 depth) { const auto typeAnnotation = safeCast.Type().Ref().GetTypeAnn(); if (!typeAnnotation) { ctx.Err << "expected non empty type annotation for safe cast"; return false; } + if (typeAnnotation->GetKind() != ETypeAnnotationKind::Type) { + ctx.Err << "expected only type expression for safe cast"; + 
return false; + } auto* dstProto = proto->mutable_cast(); - dstProto->set_type(FormatType(typeAnnotation->Cast()->GetType())); - return SerializeExpression(TExprBase(safeCast.Value()), dstProto->mutable_value(), ctx, depth + 1); + if (!SerializeExpression(safeCast.Value(), dstProto->mutable_value(), ctx, depth + 1)) { + return false; + } + + auto type = typeAnnotation->Cast()->GetType(); + auto* dstType = dstProto->mutable_type(); + if (type->GetKind() == ETypeAnnotationKind::Optional) { + type = type->Cast()->GetItemType(); + dstType = dstType->mutable_optional_type()->mutable_item(); + } + if (type->GetKind() != ETypeAnnotationKind::Data) { + ctx.Err << "expected only data type or optional data type for safe cast"; + return false; + } + const auto dataSlot = type->Cast()->GetSlot(); + + MATCH_TYPE(Bool, BOOL); + MATCH_TYPE(Int8, INT8); + MATCH_TYPE(Int16, INT16); + MATCH_TYPE(Int32, INT32); + MATCH_TYPE(Int64, INT64); + MATCH_TYPE(Uint8, UINT8); + MATCH_TYPE(Uint16, UINT16); + MATCH_TYPE(Uint32, UINT32); + MATCH_TYPE(Uint64, UINT64); + MATCH_TYPE(Float, FLOAT); + MATCH_TYPE(Double, DOUBLE); + MATCH_TYPE(String, STRING); + MATCH_TYPE(Utf8, UTF8); + MATCH_TYPE(Json, JSON); + + ctx.Err << "unknown data slot " << static_cast(dataSlot) << " for safe cast"; + return false; } +#undef MATCH_TYPE + bool SerializeToBytesExpression(const TExprBase& toBytes, TExpression* proto, TSerializationContext& ctx, ui64 depth) { if (toBytes.Ref().ChildrenSize() != 1) { ctx.Err << "invalid ToBytes expression, expected 1 child but got " << toBytes.Ref().ChildrenSize(); @@ -103,7 +146,7 @@ namespace NYql { } auto* dstProto = proto->mutable_cast(); - dstProto->set_type("String"); + dstProto->mutable_type()->set_type_id(Ydb::Type::STRING); return SerializeExpression(toBytexExpr, dstProto->mutable_value(), ctx, depth + 1); } @@ -126,7 +169,7 @@ namespace NYql { ctx.LambdaArgs.insert({lambdaArgs.Ref().Child(0), *dstInput}); return SerializeExpression(lambda.Body(), 
dstProto->mutable_then_expression(), ctx, depth + 1); } - + #define MATCH_ATOM(AtomType, ATOM_ENUM, proto_name, cpp_type) \ if (auto atom = expression.Maybe()) { \ auto* value = proto->mutable_typed_value(); \ @@ -482,7 +525,53 @@ namespace NYql { case Ydb::Value::kTextValue: return NFq::EncloseAndEscapeString(value.value().text_value(), '"'); default: - throw yexception() << "ErrUnimplementedTypedValue, value case " << static_cast(value.value().value_case()); + throw yexception() << "Failed to format ydb typed value, value case " << static_cast(value.value().value_case()) << " is not supported"; + } + } + + TString FormatPrimitiveType(const Ydb::Type::PrimitiveTypeId& typeId) { + switch (typeId) { + case Ydb::Type::BOOL: + return "Bool"; + case Ydb::Type::INT8: + return "Int8"; + case Ydb::Type::INT16: + return "Int16"; + case Ydb::Type::INT32: + return "Int32"; + case Ydb::Type::INT64: + return "Int64"; + case Ydb::Type::UINT8: + return "Uint8"; + case Ydb::Type::UINT16: + return "Uint16"; + case Ydb::Type::UINT32: + return "Uint32"; + case Ydb::Type::UINT64: + return "Uint64"; + case Ydb::Type::FLOAT: + return "Float"; + case Ydb::Type::DOUBLE: + return "Double"; + case Ydb::Type::STRING: + return "String"; + case Ydb::Type::UTF8: + return "Utf8"; + case Ydb::Type::JSON: + return "Json"; + default: + throw yexception() << "Failed to format primitive type, type id " << static_cast(typeId) << " is not supported"; + } + } + + TString FormatType(const Ydb::Type& type) { + switch (type.type_case()) { + case Ydb::Type::kTypeId: + return FormatPrimitiveType(type.type_id()); + case Ydb::Type::kOptionalType: + return TStringBuilder() << FormatType(type.optional_type().item()) << "?"; + default: + throw yexception() << "Failed to format ydb type, type case " << static_cast(type.type_case()) << " is not supported"; } } @@ -492,7 +581,8 @@ namespace NYql { TString FormatCast(const TExpression::TCast& cast) { auto value = FormatExpression(cast.value()); - return
TStringBuilder() << "CAST(" << value << " AS " << cast.type() << ")"; + auto type = FormatType(cast.type()); + return TStringBuilder() << "CAST(" << value << " AS " << type << ")"; } TString FormatExpression(const TExpression& expression) {