Skip to content

Commit

Permalink
Fixed type passing in proto
Browse files Browse the repository at this point in the history
  • Loading branch information
GrigoriyPA committed Nov 28, 2024
1 parent a7824a5 commit 623f690
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ message TExpression {
// CAST($value AS $type)
message TCast {
TExpression value = 1;
string type = 2;
Ydb.Type type = 2;
}

message TNull {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,61 @@ namespace NYql {

bool SerializeExpression(const TExprBase& expression, TExpression* proto, TSerializationContext& ctx, ui64 depth);

#define MATCH_TYPE(DataType, PROTO_TYPE) \
if (dataSlot == NUdf::EDataSlot::DataType) { \
dstType->set_type_id(Ydb::Type::PROTO_TYPE); \
return true; \
}

bool SerializeCastExpression(const TCoSafeCast& safeCast, TExpression* proto, TSerializationContext& ctx, ui64 depth) {
const auto typeAnnotation = safeCast.Type().Ref().GetTypeAnn();
if (!typeAnnotation) {
ctx.Err << "expected non empty type annotation for safe cast";
return false;
}
if (typeAnnotation->GetKind() != ETypeAnnotationKind::Type) {
ctx.Err << "expected only type expression for safe cast";
return false;
}

auto* dstProto = proto->mutable_cast();
dstProto->set_type(FormatType(typeAnnotation->Cast<TTypeExprType>()->GetType()));
return SerializeExpression(TExprBase(safeCast.Value()), dstProto->mutable_value(), ctx, depth + 1);
if (!SerializeExpression(safeCast.Value(), dstProto->mutable_value(), ctx, depth + 1)) {
return false;
}

auto type = typeAnnotation->Cast<TTypeExprType>()->GetType();
auto* dstType = dstProto->mutable_type();
if (type->GetKind() == ETypeAnnotationKind::Optional) {
type = type->Cast<TOptionalExprType>()->GetItemType();
dstType = dstType->mutable_optional_type()->mutable_item();
}
if (type->GetKind() != ETypeAnnotationKind::Data) {
ctx.Err << "expected only data type or optional data type for safe cast";
return false;
}
const auto dataSlot = type->Cast<TDataExprType>()->GetSlot();

MATCH_TYPE(Bool, BOOL);
MATCH_TYPE(Int8, INT8);
MATCH_TYPE(Int16, INT16);
MATCH_TYPE(Int32, INT32);
MATCH_TYPE(Int64, INT64);
MATCH_TYPE(Uint8, UINT8);
MATCH_TYPE(Uint16, UINT16);
MATCH_TYPE(Uint32, UINT32);
MATCH_TYPE(Uint64, UINT64);
MATCH_TYPE(Float, FLOAT);
MATCH_TYPE(Double, DOUBLE);
MATCH_TYPE(String, STRING);
MATCH_TYPE(Utf8, UTF8);
MATCH_TYPE(Json, JSON);

ctx.Err << "unknown data slot " << static_cast<ui64>(dataSlot) << " for safe cast";
return false;
}

#undef MATCH_TYPE

bool SerializeToBytesExpression(const TExprBase& toBytes, TExpression* proto, TSerializationContext& ctx, ui64 depth) {
if (toBytes.Ref().ChildrenSize() != 1) {
ctx.Err << "invalid ToBytes expression, expected 1 child but got " << toBytes.Ref().ChildrenSize();
Expand Down Expand Up @@ -103,7 +146,7 @@ namespace NYql {
}

auto* dstProto = proto->mutable_cast();
dstProto->set_type("String");
dstProto->mutable_type()->set_type_id(Ydb::Type::STRING);
return SerializeExpression(toBytexExpr, dstProto->mutable_value(), ctx, depth + 1);
}

Expand All @@ -126,7 +169,7 @@ namespace NYql {
ctx.LambdaArgs.insert({lambdaArgs.Ref().Child(0), *dstInput});
return SerializeExpression(lambda.Body(), dstProto->mutable_then_expression(), ctx, depth + 1);
}

#define MATCH_ATOM(AtomType, ATOM_ENUM, proto_name, cpp_type) \
if (auto atom = expression.Maybe<Y_CAT(TCo, AtomType)>()) { \
auto* value = proto->mutable_typed_value(); \
Expand Down Expand Up @@ -482,7 +525,53 @@ namespace NYql {
case Ydb::Value::kTextValue:
return NFq::EncloseAndEscapeString(value.value().text_value(), '"');
default:
throw yexception() << "ErrUnimplementedTypedValue, value case " << static_cast<ui64>(value.value().value_case());
throw yexception() << "Failed to format ydb typed vlaue, value case " << static_cast<ui64>(value.value().value_case()) << " is not supported";
}
}

TString FormatPrimitiveType(const Ydb::Type::PrimitiveTypeId& typeId) {
switch (typeId) {
case Ydb::Type::BOOL:
return "Bool";
case Ydb::Type::INT8:
return "Int8";
case Ydb::Type::INT16:
return "Int16";
case Ydb::Type::INT32:
return "Int32";
case Ydb::Type::INT64:
return "Int64";
case Ydb::Type::UINT8:
return "Uint8";
case Ydb::Type::UINT16:
return "Uint16";
case Ydb::Type::UINT32:
return "Uint32";
case Ydb::Type::UINT64:
return "Uint64";
case Ydb::Type::FLOAT:
return "Float";
case Ydb::Type::DOUBLE:
return "Double";
case Ydb::Type::STRING:
return "String";
case Ydb::Type::UTF8:
return "Utf8";
case Ydb::Type::JSON:
return "Json";
default:
throw yexception() << "Failed to format primitive type, type case " << static_cast<ui64>(typeId) << " is not supported";
}
}

TString FormatType(const Ydb::Type& type) {
switch (type.type_case()) {
case Ydb::Type::kTypeId:
return FormatPrimitiveType(type.type_id());
case Ydb::Type::kOptionalType:
return TStringBuilder() << FormatType(type.optional_type().item()) << "?";
default:
throw yexception() << "Failed to format ydb type, type id " << static_cast<ui64>(type.type_case()) << " is not supported";
}
}

Expand All @@ -492,7 +581,8 @@ namespace NYql {

TString FormatCast(const TExpression::TCast& cast) {
auto value = FormatExpression(cast.value());
return TStringBuilder() << "CAST(" << value << " AS " << cast.type() << ")";
auto type = FormatType(cast.type());
return TStringBuilder() << "CAST(" << value << " AS " << type << ")";
}

TString FormatExpression(const TExpression& expression) {
Expand Down

0 comments on commit 623f690

Please sign in to comment.