Skip to content

Commit

Permalink
[native] Refactor TypeSignatureTypeConverter and test
Browse files Browse the repository at this point in the history
  • Loading branch information
mbasmanova committed Oct 13, 2023
1 parent ca11fd6 commit 0e52fc3
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 462 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"

using namespace facebook::velox;

namespace facebook::presto {

TypePtr parseTypeSignature(const std::string& signature) {
Expand All @@ -32,9 +33,9 @@ TypePtr parseTypeSignature(const std::string& signature) {
// static
TypePtr TypeSignatureTypeConverter::parse(const std::string& text) {
antlr4::ANTLRInputStream input(text);
TypeSignatureLexer lexer(&input);
type::TypeSignatureLexer lexer(&input);
antlr4::CommonTokenStream tokens(&lexer);
TypeSignatureParser parser(&tokens);
type::TypeSignatureParser parser(&tokens);

parser.setErrorHandler(std::make_shared<antlr4::BailErrorStrategy>());

Expand All @@ -47,14 +48,63 @@ TypePtr TypeSignatureTypeConverter::parse(const std::string& text) {
}
}

namespace {

TypePtr typeFromString(const std::string& typeName) {
auto upper = boost::to_upper_copy(typeName);

if (upper == "UNKNOWN") {
return UNKNOWN();
}

if (upper == TIMESTAMP_WITH_TIME_ZONE()->toString()) {
return TIMESTAMP_WITH_TIME_ZONE();
}

if (upper == HYPERLOGLOG()->toString()) {
return HYPERLOGLOG();
}

if (upper == JSON()->toString()) {
return JSON();
}

if (upper == "INT") {
upper = "INTEGER";
} else if (upper == "DOUBLE PRECISION") {
upper = "DOUBLE";
}

if (upper == INTERVAL_DAY_TIME()->toString()) {
return INTERVAL_DAY_TIME();
}

if (upper == INTERVAL_YEAR_MONTH()->toString()) {
return INTERVAL_YEAR_MONTH();
}

if (upper == DATE()->toString()) {
return DATE();
}

return createScalarType(mapNameToTypeKind(upper));
}

struct NamedType {
std::string name;
velox::TypePtr type;
};

} // namespace

antlrcpp::Any TypeSignatureTypeConverter::visitStart(
TypeSignatureParser::StartContext* ctx) {
type::TypeSignatureParser::StartContext* ctx) {
NamedType named = visit(ctx->type_spec()).as<NamedType>();
return named.type;
}

antlrcpp::Any TypeSignatureTypeConverter::visitType_spec(
TypeSignatureParser::Type_specContext* ctx) {
type::TypeSignatureParser::Type_specContext* ctx) {
if (ctx->named_type()) {
return visit(ctx->named_type());
} else {
Expand All @@ -63,40 +113,39 @@ antlrcpp::Any TypeSignatureTypeConverter::visitType_spec(
}

antlrcpp::Any TypeSignatureTypeConverter::visitNamed_type(
TypeSignatureParser::Named_typeContext* ctx) {
type::TypeSignatureParser::Named_typeContext* ctx) {
return NamedType{
visit(ctx->identifier()).as<std::string>(),
visit(ctx->type()).as<TypePtr>()};
}

antlrcpp::Any TypeSignatureTypeConverter::visitType(
TypeSignatureParser::TypeContext* ctx) {
type::TypeSignatureParser::TypeContext* ctx) {
return visitChildren(ctx);
}

antlrcpp::Any TypeSignatureTypeConverter::visitSimple_type(
TypeSignatureParser::Simple_typeContext* ctx) {
type::TypeSignatureParser::Simple_typeContext* ctx) {
return ctx->WORD() ? typeFromString(ctx->WORD()->getText())
: typeFromString(ctx->TYPE_WITH_SPACES()->getText());
}

antlrcpp::Any TypeSignatureTypeConverter::visitDecimal_type(
TypeSignatureParser::Decimal_typeContext* ctx) {
if (ctx->NUMBER().size() != 2) {
VELOX_USER_FAIL("Invalid decimal type");
}
auto precision = ctx->NUMBER(0)->getText();
auto scale = ctx->NUMBER(1)->getText();
type::TypeSignatureParser::Decimal_typeContext* ctx) {
VELOX_USER_CHECK_EQ(2, ctx->NUMBER().size(), "Invalid decimal type");

const auto precision = ctx->NUMBER(0)->getText();
const auto scale = ctx->NUMBER(1)->getText();
return DECIMAL(std::atoi(precision.c_str()), std::atoi(scale.c_str()));
}

antlrcpp::Any TypeSignatureTypeConverter::visitVariable_type(
TypeSignatureParser::Variable_typeContext* ctx) {
type::TypeSignatureParser::Variable_typeContext* ctx) {
return typeFromString(ctx->WORD()->getText());
}

antlrcpp::Any TypeSignatureTypeConverter::visitType_list(
TypeSignatureParser::Type_listContext* ctx) {
type::TypeSignatureParser::Type_listContext* ctx) {
std::vector<NamedType> types;
for (auto type_spec : ctx->type_spec()) {
types.emplace_back(visit(type_spec).as<NamedType>());
Expand All @@ -105,24 +154,38 @@ antlrcpp::Any TypeSignatureTypeConverter::visitType_list(
}

antlrcpp::Any TypeSignatureTypeConverter::visitRow_type(
TypeSignatureParser::Row_typeContext* ctx) {
return rowFromNamedTypes(
visit(ctx->type_list()).as<std::vector<NamedType>>());
type::TypeSignatureParser::Row_typeContext* ctx) {
const auto namedTypes = visit(ctx->type_list()).as<std::vector<NamedType>>();

std::vector<std::string> names;
std::vector<TypePtr> types;
names.reserve(namedTypes.size());
types.reserve(namedTypes.size());
for (const auto& namedType : namedTypes) {
names.push_back(namedType.name);
types.push_back(namedType.type);
}

const TypePtr rowType = ROW(std::move(names), std::move(types));
return rowType;
}

antlrcpp::Any TypeSignatureTypeConverter::visitMap_type(
TypeSignatureParser::Map_typeContext* ctx) {
return mapFromKeyValueType(
visit(ctx->type()[0]).as<TypePtr>(), visit(ctx->type()[1]).as<TypePtr>());
type::TypeSignatureParser::Map_typeContext* ctx) {
const auto keyType = visit(ctx->type()[0]).as<TypePtr>();
const auto valueType = visit(ctx->type()[1]).as<TypePtr>();
const TypePtr mapType = MAP(keyType, valueType);
return mapType;
}

antlrcpp::Any TypeSignatureTypeConverter::visitArray_type(
TypeSignatureParser::Array_typeContext* ctx) {
return arrayFromType(visit(ctx->type()).as<TypePtr>());
type::TypeSignatureParser::Array_typeContext* ctx) {
const TypePtr arrayType = ARRAY(visit(ctx->type()).as<TypePtr>());
return arrayType;
}

antlrcpp::Any TypeSignatureTypeConverter::visitFunction_type(
TypeSignatureParser::Function_typeContext* ctx) {
type::TypeSignatureParser::Function_typeContext* ctx) {
const auto numArgs = ctx->type().size() - 1;

std::vector<TypePtr> argumentTypes;
Expand All @@ -138,7 +201,7 @@ antlrcpp::Any TypeSignatureTypeConverter::visitFunction_type(
}

antlrcpp::Any TypeSignatureTypeConverter::visitIdentifier(
TypeSignatureParser::IdentifierContext* ctx) {
type::TypeSignatureParser::IdentifierContext* ctx) {
if (ctx->WORD()) {
return ctx->WORD()->getText();
} else {
Expand All @@ -147,67 +210,4 @@ antlrcpp::Any TypeSignatureTypeConverter::visitIdentifier(
}
}

TypePtr typeFromString(const std::string& typeName) {
auto upper = boost::to_upper_copy(typeName);

if (upper == "UNKNOWN") {
return UNKNOWN();
}

if (upper == TIMESTAMP_WITH_TIME_ZONE()->toString()) {
return TIMESTAMP_WITH_TIME_ZONE();
}

if (upper == HYPERLOGLOG()->toString()) {
return HYPERLOGLOG();
}

if (upper == JSON()->toString()) {
return JSON();
}

if (upper == "INT") {
upper = "INTEGER";
} else if (upper == "DOUBLE PRECISION") {
upper = "DOUBLE";
}

if (upper == INTERVAL_DAY_TIME()->toString()) {
return INTERVAL_DAY_TIME();
}

if (upper == INTERVAL_YEAR_MONTH()->toString()) {
return INTERVAL_YEAR_MONTH();
}

if (upper == DATE()->toString()) {
return DATE();
}

return createScalarType(mapNameToTypeKind(upper));
}

TypePtr rowFromNamedTypes(const std::vector<NamedType>& named) {
std::vector<std::string> names{};
std::transform(
named.begin(), named.end(), std::back_inserter(names), [](NamedType v) {
return v.name;
});
std::vector<TypePtr> types{};
std::transform(
named.begin(), named.end(), std::back_inserter(types), [](NamedType v) {
return v.type;
});

return TypeFactory<TypeKind::ROW>::create(std::move(names), std::move(types));
}

TypePtr mapFromKeyValueType(TypePtr keyType, TypePtr valueType) {
return TypeFactory<TypeKind::MAP>::create(keyType, valueType);
}

TypePtr arrayFromType(TypePtr valueType) {
return TypeFactory<TypeKind::ARRAY>::create(valueType);
}

} // namespace facebook::presto
Original file line number Diff line number Diff line change
Expand Up @@ -18,53 +18,39 @@

#include "presto_cpp/main/types/antlr/TypeSignatureBaseVisitor.h"

namespace facebook {
namespace presto {
using namespace type;
namespace facebook::presto {

class TypeSignatureTypeConverter : TypeSignatureBaseVisitor {
class TypeSignatureTypeConverter : type::TypeSignatureBaseVisitor {
public:
static velox::TypePtr parse(const std::string& text);

private:
virtual antlrcpp::Any visitStart(
TypeSignatureParser::StartContext* ctx) override;
type::TypeSignatureParser::StartContext* ctx) override;
virtual antlrcpp::Any visitNamed_type(
TypeSignatureParser::Named_typeContext* ctx) override;
type::TypeSignatureParser::Named_typeContext* ctx) override;
virtual antlrcpp::Any visitType_spec(
TypeSignatureParser::Type_specContext* ctx) override;
type::TypeSignatureParser::Type_specContext* ctx) override;
virtual antlrcpp::Any visitType(
TypeSignatureParser::TypeContext* ctx) override;
type::TypeSignatureParser::TypeContext* ctx) override;
virtual antlrcpp::Any visitSimple_type(
TypeSignatureParser::Simple_typeContext* ctx) override;
type::TypeSignatureParser::Simple_typeContext* ctx) override;
virtual antlrcpp::Any visitDecimal_type(
TypeSignatureParser::Decimal_typeContext* ctx) override;
type::TypeSignatureParser::Decimal_typeContext* ctx) override;
virtual antlrcpp::Any visitVariable_type(
TypeSignatureParser::Variable_typeContext* ctx) override;
type::TypeSignatureParser::Variable_typeContext* ctx) override;
virtual antlrcpp::Any visitType_list(
TypeSignatureParser::Type_listContext* ctx) override;
type::TypeSignatureParser::Type_listContext* ctx) override;
virtual antlrcpp::Any visitRow_type(
TypeSignatureParser::Row_typeContext* ctx) override;
type::TypeSignatureParser::Row_typeContext* ctx) override;
virtual antlrcpp::Any visitMap_type(
TypeSignatureParser::Map_typeContext* ctx) override;
type::TypeSignatureParser::Map_typeContext* ctx) override;
virtual antlrcpp::Any visitArray_type(
TypeSignatureParser::Array_typeContext* ctx) override;
type::TypeSignatureParser::Array_typeContext* ctx) override;
virtual antlrcpp::Any visitFunction_type(
TypeSignatureParser::Function_typeContext* ctx) override;
type::TypeSignatureParser::Function_typeContext* ctx) override;
virtual antlrcpp::Any visitIdentifier(
TypeSignatureParser::IdentifierContext* ctx) override;

public:
static std::shared_ptr<const velox::Type> parse(const std::string& text);
type::TypeSignatureParser::IdentifierContext* ctx) override;
};

struct NamedType {
std::string name;
velox::TypePtr type;
};

velox::TypePtr typeFromString(const std::string& typeName);
velox::TypePtr rowFromNamedTypes(const std::vector<NamedType>& named);
velox::TypePtr mapFromKeyValueType(
velox::TypePtr keyType,
velox::TypePtr valueType);
velox::TypePtr arrayFromType(velox::TypePtr valueType);

} // namespace presto
} // namespace facebook
} // namespace facebook::presto
Loading

0 comments on commit 0e52fc3

Please sign in to comment.