Skip to content

Commit

Permalink
Support enum literals in queries
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-ohayon authored and Rongrong Zhong committed Aug 31, 2020
1 parent f99842b commit 3649366
Show file tree
Hide file tree
Showing 13 changed files with 230 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ private static Object fixValue(TypeSignature signature, Object value)
}
return fixedValue;
}
if (signature.isVarcharEnum()) {
return String.class.cast(value);
}
switch (signature.getBase()) {
case BIGINT:
if (value instanceof String) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.facebook.presto.common.type.CharType;
import com.facebook.presto.common.type.DecimalParseResult;
import com.facebook.presto.common.type.Decimals;
import com.facebook.presto.common.type.EnumType;
import com.facebook.presto.common.type.FunctionType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.StandardTypes;
Expand Down Expand Up @@ -58,6 +59,7 @@
import com.facebook.presto.sql.tree.DecimalLiteral;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.DoubleLiteral;
import com.facebook.presto.sql.tree.EnumLiteral;
import com.facebook.presto.sql.tree.ExistsPredicate;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.Extract;
Expand Down Expand Up @@ -140,6 +142,7 @@
import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy;
import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoAggregateWindowOrGroupingFunctions;
import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoExternalFunctions;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.tryResolveEnumLiteralType;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.EXPRESSION_NOT_CONSTANT;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_LITERAL;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE;
Expand Down Expand Up @@ -432,14 +435,21 @@ protected Type visitDereferenceExpression(DereferenceExpression node, StackableA
{
QualifiedName qualifiedName = DereferenceExpression.getQualifiedName(node);

// If this Dereference looks like column reference, try match it to column first.
// Handle qualified name
if (qualifiedName != null) {
// first, try to match it to a column name
Scope scope = context.getContext().getScope();
Optional<ResolvedField> resolvedField = scope.tryResolveField(node, qualifiedName);
if (resolvedField.isPresent()) {
return handleResolvedField(node, resolvedField.get(), context);
}
// otherwise, try to match it to an enum literal (eg Mood.HAPPY)
if (!scope.isColumnReference(qualifiedName)) {
Optional<EnumType> enumType = tryResolveEnumLiteralType(qualifiedName, typeManager);
if (enumType.isPresent()) {
setExpressionType(node.getBase(), enumType.get());
return setExpressionType(node, enumType.get());
}
throw missingAttributeException(node, qualifiedName);
}
}
Expand Down Expand Up @@ -773,6 +783,20 @@ protected Type visitGenericLiteral(GenericLiteral node, StackableAstVisitorConte
return setExpressionType(node, type);
}

@Override
protected Type visitEnumLiteral(EnumLiteral node, StackableAstVisitorContext<Context> context)
{
Type type;
try {
type = typeManager.getType(parseTypeSignature(node.getType()));
}
catch (IllegalArgumentException e) {
throw new SemanticException(TYPE_MISMATCH, node, "Unknown type: " + node.getType());
}

return setExpressionType(node, type);
}

@Override
protected Type visitTimeLiteral(TimeLiteral node, StackableAstVisitorContext<Context> context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,34 @@
*/
package com.facebook.presto.sql.analyzer;

import com.facebook.presto.common.type.EnumType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.metadata.FunctionManager;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.sql.tree.QualifiedName;
import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;

import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Predicates.alwaysTrue;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.slice.Slices.utf8Slice;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;

public final class ExpressionTreeUtils
Expand Down Expand Up @@ -112,4 +123,50 @@ public static boolean isEqualComparisonExpression(Expression expression)
{
return expression instanceof ComparisonExpression && ((ComparisonExpression) expression).getOperator() == ComparisonExpression.Operator.EQUAL;
}

static Optional<EnumType> tryResolveEnumLiteralType(QualifiedName qualifiedName, TypeManager typeManager)
{
Optional<QualifiedName> prefix = qualifiedName.getPrefix();
if (!prefix.isPresent()) {
// an enum literal should be of the form `MyEnum.my_key`
return Optional.empty();
}
try {
Type baseType = typeManager.getType(parseTypeSignature(prefix.get().toString()));
if (baseType instanceof EnumType) {
return Optional.of((EnumType) baseType);
}
}
catch (IllegalArgumentException e) {
return Optional.empty();
}
return Optional.empty();
}

private static boolean isEnumLiteral(DereferenceExpression node, Type nodeType)
{
if (!(nodeType instanceof EnumType)) {
return false;
}
QualifiedName qualifiedName = DereferenceExpression.getQualifiedName(node);
if (qualifiedName == null) {
return false;
}
Optional<QualifiedName> prefix = qualifiedName.getPrefix();
return prefix.isPresent()
&& prefix.get().toString().equalsIgnoreCase(nodeType.getTypeSignature().getBase());
}

public static Optional<Object> tryResolveEnumLiteral(DereferenceExpression node, Type nodeType)
{
QualifiedName qualifiedName = DereferenceExpression.getQualifiedName(node);
if (!isEnumLiteral(node, nodeType)) {
return Optional.empty();
}
EnumType enumType = (EnumType) nodeType;
String enumKey = qualifiedName.getSuffix().toUpperCase(ENGLISH);
checkArgument(enumType.getEnumMap().containsKey(enumKey), format("No key '%s' in enum '%s'", enumKey, nodeType.getDisplayName()));
Object enumValue = enumType.getEnumMap().get(enumKey);
return enumValue instanceof String ? Optional.of(utf8Slice((String) enumValue)) : Optional.of(enumValue);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ private SemanticExceptions() {}

public static SemanticException missingAttributeException(Expression node, QualifiedName name)
{
throw new SemanticException(MISSING_ATTRIBUTE, node, "Column '%s' cannot be resolved", name);
throw new SemanticException(
MISSING_ATTRIBUTE,
node,
name.getPrefix().isPresent() ? "'%s' cannot be resolved" : "Column '%s' cannot be resolved",
name);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
import static com.facebook.presto.sql.analyzer.ConstantExpressionVerifier.verifyExpressionIsConstant;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.createConstantAnalyzer;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.tryResolveEnumLiteral;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.gen.VarArgsToMapAdapterGenerator.generateVarArgsToMapAdapter;
import static com.facebook.presto.sql.planner.ExpressionDeterminismEvaluator.isDeterministic;
Expand Down Expand Up @@ -298,6 +299,12 @@ public Object visitFieldReference(FieldReference node, Object context)
@Override
protected Object visitDereferenceExpression(DereferenceExpression node, Object context)
{
Type returnType = type(node);
Optional<Object> maybeEnumValue = tryResolveEnumLiteral(node, returnType);
if (maybeEnumValue.isPresent()) {
return maybeEnumValue.get();
}

Type type = type(node.getBase());
// if there is no type for the base of Dereference, it must be QualifiedName
if (type == null) {
Expand All @@ -315,7 +322,6 @@ protected Object visitDereferenceExpression(DereferenceExpression node, Object c
}

RowType rowType = (RowType) type;
Type returnType = type(node);
String fieldName = node.getField().getValue();
List<Field> fields = rowType.getFields();
int index = -1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import com.facebook.presto.sql.tree.CharLiteral;
import com.facebook.presto.sql.tree.DecimalLiteral;
import com.facebook.presto.sql.tree.DoubleLiteral;
import com.facebook.presto.sql.tree.EnumLiteral;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.GenericLiteral;
import com.facebook.presto.sql.tree.IntervalLiteral;
Expand Down Expand Up @@ -235,6 +236,12 @@ protected Slice visitBinaryLiteral(BinaryLiteral node, ConnectorSession session)
return node.getValue();
}

@Override
protected Object visitEnumLiteral(EnumLiteral node, ConnectorSession context)
{
return node.getValue();
}

@Override
protected Object visitGenericLiteral(GenericLiteral node, ConnectorSession session)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,11 @@
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.DefaultTraversalVisitor;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.EnumLiteral;
import com.facebook.presto.sql.tree.Except;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.InPredicate;
Expand Down Expand Up @@ -97,6 +100,7 @@
import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.isEqualComparisonExpression;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.tryResolveEnumLiteral;
import static com.facebook.presto.sql.analyzer.SemanticExceptions.notSupportedException;
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference;
Expand Down Expand Up @@ -670,22 +674,40 @@ protected RelationPlan visitValues(Values node, Void context)
ImmutableList.Builder<RowExpression> values = ImmutableList.builder();
if (row instanceof Row) {
for (Expression item : ((Row) row).getItems()) {
Expression expression = Coercer.addCoercions(item, analysis);
values.add(castToRowExpression(ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(analysis.getParameters(), analysis), expression)));
values.add(rewriteRow(item));
}
}
else {
Expression expression = Coercer.addCoercions(row, analysis);
values.add(castToRowExpression(ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(analysis.getParameters(), analysis), expression)));
values.add(rewriteRow(row));
}

rowsBuilder.add(values.build());
}

ValuesNode valuesNode = new ValuesNode(idAllocator.getNextId(), outputVariablesBuilder.build(), rowsBuilder.build());
return new RelationPlan(valuesNode, scope, outputVariablesBuilder.build());
}

private RowExpression rewriteRow(Expression row)
{
Expression expression = Coercer.addCoercions(row, analysis);
expression = ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(analysis.getParameters(), analysis), expression);

// resolve enum literals
expression = ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>() {
@Override
public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
Type nodeType = analysis.getType(node);
Optional<Object> maybeEnumValue = tryResolveEnumLiteral(node, nodeType);
if (maybeEnumValue.isPresent()) {
return new EnumLiteral(nodeType.getTypeSignature().toString(), maybeEnumValue.get());
}
return node;
}
}, expression);
return castToRowExpression(expression);
}

@Override
protected RelationPlan visitUnnest(Unnest node, Void context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.facebook.presto.sql.analyzer.ResolvedField;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.EnumLiteral;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
Expand All @@ -36,6 +37,7 @@
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.tryResolveEnumLiteral;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -234,6 +236,13 @@ public Expression rewriteDereferenceExpression(DereferenceExpression node, Void
// do not rewrite outer references, it will be handled in outer scope planner
return node;
}

Type nodeType = analysis.getType(node);
Optional<Object> maybeEnumValue = tryResolveEnumLiteral(node, nodeType);
if (maybeEnumValue.isPresent()) {
return new EnumLiteral(nodeType.getTypeSignature().toString(), maybeEnumValue.get());
}

return rewriteExpression(node, context, treeRewriter);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.ROW_CONSTRUCTOR;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.tryResolveEnumLiteral;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.constant;
Expand Down Expand Up @@ -562,6 +563,12 @@ private RowExpression buildSwitch(RowExpression operand, List<WhenClause> whenCl
@Override
protected RowExpression visitDereferenceExpression(DereferenceExpression node, Void context)
{
Type returnType = getType(node);
Optional<Object> maybeEnumLiteral = tryResolveEnumLiteral(node, returnType);
if (maybeEnumLiteral.isPresent()) {
return constant(maybeEnumLiteral.get(), returnType);
}

RowType rowType = (RowType) getType(node.getBase());
String fieldName = node.getField().getValue();
List<Field> fields = rowType.getFields();
Expand All @@ -582,7 +589,6 @@ protected RowExpression visitDereferenceExpression(DereferenceExpression node, V
}

checkState(index >= 0, "could not find field name: %s", node.getField());
Type returnType = getType(node);
return specialForm(DEREFERENCE, returnType, process(node.getBase(), context), constant((long) index, INTEGER));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ public void testInvalidAttribute()
assertFails(MISSING_ATTRIBUTE, "SELECT * FROM t1 WHERE f > 1");
}

@Test(expectedExceptions = SemanticException.class, expectedExceptionsMessageRegExp = "line 1:8: Column 't.y' cannot be resolved")
@Test(expectedExceptions = SemanticException.class, expectedExceptionsMessageRegExp = "line 1:8: 't.y' cannot be resolved")
public void testInvalidAttributeCorrectErrorMessage()
{
analyze("SELECT t.y FROM (VALUES 1) t(x)");
Expand Down Expand Up @@ -1165,7 +1165,7 @@ public void testCreateTableAsColumns()
assertFails(MISMATCHED_COLUMN_ALIASES, 1, 19, "CREATE TABLE test(x, y) AS (VALUES 1)");
assertFails(DUPLICATE_COLUMN_NAME, 1, 24, "CREATE TABLE test(abc, AbC) AS SELECT 1, 2");
assertFails(COLUMN_TYPE_UNKNOWN, 1, 1, "CREATE TABLE test(x) AS SELECT null");
assertFails(MISSING_ATTRIBUTE, ".*Column 'y' cannot be resolved", "CREATE TABLE test(x) WITH (p1 = y) AS SELECT null");
assertFails(MISSING_ATTRIBUTE, ".*'y' cannot be resolved", "CREATE TABLE test(x) WITH (p1 = y) AS SELECT null");
assertFails(DUPLICATE_PROPERTY, ".* Duplicate property: p1", "CREATE TABLE test(x) WITH (p1 = 'p1', p2 = 'p2', p1 = 'p3') AS SELECT null");
assertFails(DUPLICATE_PROPERTY, ".* Duplicate property: p1", "CREATE TABLE test(x) WITH (p1 = 'p1', \"p1\" = 'p2') AS SELECT null");
}
Expand All @@ -1176,7 +1176,7 @@ public void testCreateTable()
analyze("CREATE TABLE test (id bigint)");
analyze("CREATE TABLE test (id bigint) WITH (p1 = 'p1')");

assertFails(MISSING_ATTRIBUTE, ".*Column 'y' cannot be resolved", "CREATE TABLE test (x bigint) WITH (p1 = y)");
assertFails(MISSING_ATTRIBUTE, ".*'y' cannot be resolved", "CREATE TABLE test (x bigint) WITH (p1 = y)");
assertFails(DUPLICATE_PROPERTY, ".* Duplicate property: p1", "CREATE TABLE test (id bigint) WITH (p1 = 'p1', p2 = 'p2', p1 = 'p3')");
assertFails(DUPLICATE_PROPERTY, ".* Duplicate property: p1", "CREATE TABLE test (id bigint) WITH (p1 = 'p1', \"p1\" = 'p2')");
}
Expand All @@ -1197,7 +1197,7 @@ public void testCreateSchema()
analyze("CREATE SCHEMA test");
analyze("CREATE SCHEMA test WITH (p1 = 'p1')");

assertFails(MISSING_ATTRIBUTE, ".*Column 'y' cannot be resolved", "CREATE SCHEMA test WITH (p1 = y)");
assertFails(MISSING_ATTRIBUTE, ".*'y' cannot be resolved", "CREATE SCHEMA test WITH (p1 = y)");
assertFails(DUPLICATE_PROPERTY, ".* Duplicate property: p1", "CREATE SCHEMA test WITH (p1 = 'p1', p2 = 'p2', p1 = 'p3')");
assertFails(DUPLICATE_PROPERTY, ".* Duplicate property: p1", "CREATE SCHEMA test WITH (p1 = 'p1', \"p1\" = 'p2')");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public void testColumnReferences()
"SELECT t.k FROM " +
"(VALUES (1, 'a')) AS t(k, v1) JOIN" +
"(VALUES (1, 'b')) AS u(k, v2) USING (k)",
".*Column 't.k' cannot be resolved.*");
".*'t.k' cannot be resolved.*");
}

@Test
Expand Down
Loading

0 comments on commit 3649366

Please sign in to comment.