diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInFunctionNamespaceManager.java index 6f095c809f2c6..e9e80fe57b92f 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInFunctionNamespaceManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInFunctionNamespaceManager.java @@ -191,6 +191,7 @@ import com.facebook.presto.type.DateTimeOperators; import com.facebook.presto.type.DecimalOperators; import com.facebook.presto.type.DoubleOperators; +import com.facebook.presto.type.EnumCasts; import com.facebook.presto.type.HyperLogLogOperators; import com.facebook.presto.type.IntegerOperators; import com.facebook.presto.type.IntervalDayTimeOperators; @@ -198,6 +199,7 @@ import com.facebook.presto.type.IpAddressOperators; import com.facebook.presto.type.IpPrefixOperators; import com.facebook.presto.type.LikeFunctions; +import com.facebook.presto.type.LongEnumOperators; import com.facebook.presto.type.QuantileDigestOperators; import com.facebook.presto.type.RealOperators; import com.facebook.presto.type.SmallintOperators; @@ -209,6 +211,7 @@ import com.facebook.presto.type.TinyintOperators; import com.facebook.presto.type.UnknownOperators; import com.facebook.presto.type.VarbinaryOperators; +import com.facebook.presto.type.VarcharEnumOperators; import com.facebook.presto.type.VarcharOperators; import com.facebook.presto.type.khyperloglog.KHyperLogLogAggregationFunction; import com.facebook.presto.type.khyperloglog.KHyperLogLogFunctions; @@ -705,7 +708,10 @@ public BuiltInFunctionNamespaceManager( .function(MergeTDigestFunction.MERGE) .sqlInvokedScalar(MapNormalizeFunction.class) .sqlInvokedScalars(ArrayArithmeticFunctions.class) - .scalar(DynamicFilterPlaceholderFunction.class); + .scalar(DynamicFilterPlaceholderFunction.class) + .scalars(EnumCasts.class) + .scalars(LongEnumOperators.class) + .scalars(VarcharEnumOperators.class); switch (featuresConfig.getRegexLibrary()) { case JONI: diff --git a/presto-main/src/main/java/com/facebook/presto/type/EnumCasts.java b/presto-main/src/main/java/com/facebook/presto/type/EnumCasts.java new file mode 100644 index 0000000000000..61ad30b329914 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/type/EnumCasts.java @@ -0,0 +1,111 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.type; + +import com.facebook.presto.common.type.LongEnumType; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharEnumType; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.ScalarOperator; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import io.airlift.slice.Slice; + +import static com.facebook.presto.common.function.OperatorType.CAST; +import static com.facebook.presto.common.type.StandardTypes.BIGINT; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; + +public final class EnumCasts +{ + private EnumCasts() + { + } + + @ScalarOperator(CAST) + @TypeParameter(value = "T", boundedBy = VarcharEnumType.class) + @SqlType("T") + public static Slice castVarcharToEnum(@TypeParameter("T") Type enumType, @SqlType(StandardTypes.VARCHAR) Slice value) + { + if (!(((VarcharEnumType) enumType).getEnumMap().values().contains(value.toStringUtf8()))) { + throw new PrestoException(INVALID_CAST_ARGUMENT, + String.format( + "No value '%s' in enum '%s'", + value.toStringUtf8(), + enumType.getTypeSignature().getBase())); + } + return value; + } + + @ScalarOperator(CAST) + @TypeParameter(value = "T", boundedBy = VarcharEnumType.class) + @SqlType(StandardTypes.VARCHAR) + public static Slice castEnumToVarchar(@SqlType("T") Slice value) + { + return value; + } + + @ScalarOperator(CAST) + @TypeParameter(value = "T", boundedBy = LongEnumType.class) + @SqlType("T") + public static long castBigintToEnum(@TypeParameter("T") Type enumType, @SqlType(BIGINT) long value) + { + return castLongToEnum(enumType, value); + } + + @ScalarOperator(CAST) + @TypeParameter(value = "T", boundedBy = LongEnumType.class) + @SqlType("T") + public static long castIntegerToEnum(@TypeParameter("T") Type enumType, @SqlType(StandardTypes.INTEGER) long value) + { + return castLongToEnum(enumType, value); + } + + @ScalarOperator(CAST) + @TypeParameter(value = "T", boundedBy = LongEnumType.class) + @SqlType("T") + public static long castSmallintToEnum(@TypeParameter("T") Type enumType, @SqlType(StandardTypes.SMALLINT) long value) + { + return castLongToEnum(enumType, value); + } + + @ScalarOperator(CAST) + @TypeParameter(value = "T", boundedBy = LongEnumType.class) + @SqlType("T") + public static long castTinyintToEnum(@TypeParameter("T") Type enumType, @SqlType(StandardTypes.TINYINT) long value) + { + return castLongToEnum(enumType, value); + } + + private static long castLongToEnum(Type enumType, long value) + { + if (!((LongEnumType) enumType).getEnumMap().values().contains(value)) { + throw new PrestoException(INVALID_CAST_ARGUMENT, + String.format( + "No value '%d' in enum '%s'", + value, + enumType.getTypeSignature().getBase())); + } + return value; + } + + @ScalarOperator(CAST) + @TypeParameter(value = "T", boundedBy = LongEnumType.class) + @SqlType(BIGINT) + public static long castEnumToBigint(@SqlType("T") long value) + { + return value; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/type/LongEnumOperators.java b/presto-main/src/main/java/com/facebook/presto/type/LongEnumOperators.java new file mode 100644 index 0000000000000..6f0391b66fb1d --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/type/LongEnumOperators.java @@ -0,0 +1,143 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.type; + +import com.facebook.presto.common.type.AbstractLongType; +import com.facebook.presto.common.type.LongEnumType; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.spi.function.IsNull; +import com.facebook.presto.spi.function.ScalarOperator; +import com.facebook.presto.spi.function.SqlNullable; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import io.airlift.slice.XxHash64; + +import static com.facebook.presto.common.function.OperatorType.BETWEEN; +import static com.facebook.presto.common.function.OperatorType.EQUAL; +import static com.facebook.presto.common.function.OperatorType.GREATER_THAN; +import static com.facebook.presto.common.function.OperatorType.GREATER_THAN_OR_EQUAL; +import static com.facebook.presto.common.function.OperatorType.HASH_CODE; +import static com.facebook.presto.common.function.OperatorType.INDETERMINATE; +import static com.facebook.presto.common.function.OperatorType.IS_DISTINCT_FROM; +import static com.facebook.presto.common.function.OperatorType.LESS_THAN; +import static com.facebook.presto.common.function.OperatorType.LESS_THAN_OR_EQUAL; +import static com.facebook.presto.common.function.OperatorType.NOT_EQUAL; +import static com.facebook.presto.common.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.common.type.StandardTypes.BIGINT; +import static com.facebook.presto.common.type.StandardTypes.BOOLEAN; + +public final class LongEnumOperators +{ + private LongEnumOperators() {} + + @ScalarOperator(EQUAL) + @TypeParameter(value = "T", boundedBy = LongEnumType.class) + @SqlType(BOOLEAN) + @SqlNullable + public static Boolean equal(@SqlType("T") long left, @SqlType("T") long right) + { + return left == right; + } + + @ScalarOperator(NOT_EQUAL) + @TypeParameter(value = "T", boundedBy = LongEnumType.class) + @SqlType(BOOLEAN) + @SqlNullable + public static Boolean notEqual(@SqlType("T") long left, @SqlType("T") long right) + { + return left != right; + } + + @ScalarOperator(IS_DISTINCT_FROM) + @TypeParameter(value = "T", boundedBy = LongEnumType.class) + @SqlType(BOOLEAN) + public static boolean isDistinctFrom( + @SqlType("T") long left, + @IsNull boolean leftNull, + @SqlType("T") long right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); + } + + @ScalarOperator(HASH_CODE) + @TypeParameter(value = "T", boundedBy = LongEnumType.class) + @SqlType(BIGINT) + public static long hashCode(@SqlType("T") long value) + { + return AbstractLongType.hash(value); + } + + @ScalarOperator(XX_HASH_64) + @TypeParameter(value = "T", boundedBy = LongEnumType.class) + @SqlType(BIGINT) + public static long xxHash64(@SqlType("T") long value) + { + return XxHash64.hash(value); + } + + @ScalarOperator(INDETERMINATE) + @TypeParameter(value = "T", boundedBy = LongEnumType.class) + @SqlType(BOOLEAN) + public static boolean indeterminate(@SqlType("T") long value, @IsNull boolean isNull) + { + return isNull; + } + + @ScalarOperator(LESS_THAN) + @TypeParameter(value = "T", boundedBy = LongEnumType.class) + @SqlType(StandardTypes.BOOLEAN) + public static boolean lessThan(@SqlType("T") long left, @SqlType("T") long right) + { + return left < right; + } + + @ScalarOperator(LESS_THAN_OR_EQUAL) + @TypeParameter(value = "T", boundedBy = LongEnumType.class) + @SqlType(StandardTypes.BOOLEAN) + public static boolean lessThanOrEqual(@SqlType("T") long left, @SqlType("T") long right) + { + return left <= right; + } + + @ScalarOperator(GREATER_THAN) + @TypeParameter(value = "T", boundedBy = LongEnumType.class) + @SqlType(StandardTypes.BOOLEAN) + public static boolean greaterThan(@SqlType("T") long left, @SqlType("T") long right) + { + return left > right; + } + + @ScalarOperator(GREATER_THAN_OR_EQUAL) + @TypeParameter(value = "T", boundedBy = LongEnumType.class) + @SqlType(StandardTypes.BOOLEAN) + public static boolean greaterThanOrEqual(@SqlType("T") long left, @SqlType("T") long right) + { + return left >= right; + } + + @ScalarOperator(BETWEEN) + @TypeParameter(value = "T", boundedBy = LongEnumType.class) + @SqlType(StandardTypes.BOOLEAN) + public static boolean between(@SqlType("T") long value, @SqlType("T") long min, @SqlType("T") long max) + { + return min <= value && value <= max; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/type/VarcharEnumOperators.java b/presto-main/src/main/java/com/facebook/presto/type/VarcharEnumOperators.java new file mode 100644 index 0000000000000..9a1ef667bedb2 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/type/VarcharEnumOperators.java @@ -0,0 +1,143 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.type; + +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.common.type.VarcharEnumType; +import com.facebook.presto.spi.function.IsNull; +import com.facebook.presto.spi.function.ScalarOperator; +import com.facebook.presto.spi.function.SqlNullable; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import io.airlift.slice.Slice; +import io.airlift.slice.XxHash64; + +import static com.facebook.presto.common.function.OperatorType.BETWEEN; +import static com.facebook.presto.common.function.OperatorType.EQUAL; +import static com.facebook.presto.common.function.OperatorType.GREATER_THAN; +import static com.facebook.presto.common.function.OperatorType.GREATER_THAN_OR_EQUAL; +import static com.facebook.presto.common.function.OperatorType.HASH_CODE; +import static com.facebook.presto.common.function.OperatorType.INDETERMINATE; +import static com.facebook.presto.common.function.OperatorType.IS_DISTINCT_FROM; +import static com.facebook.presto.common.function.OperatorType.LESS_THAN; +import static com.facebook.presto.common.function.OperatorType.LESS_THAN_OR_EQUAL; +import static com.facebook.presto.common.function.OperatorType.NOT_EQUAL; +import static com.facebook.presto.common.function.OperatorType.XX_HASH_64; +import static com.facebook.presto.common.type.StandardTypes.BIGINT; +import static com.facebook.presto.common.type.StandardTypes.BOOLEAN; + +public final class VarcharEnumOperators +{ + private VarcharEnumOperators() {} + + @ScalarOperator(EQUAL) + @TypeParameter(value = "T", boundedBy = VarcharEnumType.class) + @SqlType(BOOLEAN) + @SqlNullable + public static Boolean equal(@SqlType("T") Slice left, @SqlType("T") Slice right) + { + return left.equals(right); + } + + @ScalarOperator(NOT_EQUAL) + @TypeParameter(value = "T", boundedBy = VarcharEnumType.class) + @SqlType(BOOLEAN) + @SqlNullable + public static Boolean notEqual(@SqlType("T") Slice left, @SqlType("T") Slice right) + { + return !left.equals(right); + } + + @ScalarOperator(IS_DISTINCT_FROM) + @TypeParameter(value = "T", boundedBy = VarcharEnumType.class) + @SqlType(BOOLEAN) + public static boolean isDistinctFrom( + @SqlType("T") Slice left, + @IsNull boolean leftNull, + @SqlType("T") Slice right, + @IsNull boolean rightNull) + { + if (leftNull != rightNull) { + return true; + } + if (leftNull) { + return false; + } + return notEqual(left, right); + } + + @ScalarOperator(HASH_CODE) + @TypeParameter(value = "T", boundedBy = VarcharEnumType.class) + @SqlType(BIGINT) + public static long hashCode(@SqlType("T") Slice value) + { + return xxHash64(value); + } + + @ScalarOperator(XX_HASH_64) + @TypeParameter(value = "T", boundedBy = VarcharEnumType.class) + @SqlType(BIGINT) + public static long xxHash64(@SqlType("T") Slice value) + { + return XxHash64.hash(value); + } + + @ScalarOperator(INDETERMINATE) + @TypeParameter(value = "T", boundedBy = VarcharEnumType.class) + @SqlType(BOOLEAN) + public static boolean indeterminate(@SqlType("T") Slice value, @IsNull boolean isNull) + { + return isNull; + } + + @ScalarOperator(LESS_THAN) + @TypeParameter(value = "T", boundedBy = VarcharEnumType.class) + @SqlType(StandardTypes.BOOLEAN) + public static boolean lessThan(@SqlType("T") Slice left, @SqlType("T") Slice right) + { + return left.compareTo(right) < 0; + } + + @ScalarOperator(LESS_THAN_OR_EQUAL) + @TypeParameter(value = "T", boundedBy = VarcharEnumType.class) + @SqlType(StandardTypes.BOOLEAN) + public static boolean lessThanOrEqual(@SqlType("T") Slice left, @SqlType("T") Slice right) + { + return left.compareTo(right) <= 0; + } + + @ScalarOperator(GREATER_THAN) + @TypeParameter(value = "T", boundedBy = VarcharEnumType.class) + @SqlType(StandardTypes.BOOLEAN) + public static boolean greaterThan(@SqlType("T") Slice left, @SqlType("T") Slice right) + { + return left.compareTo(right) > 0; + } + + @ScalarOperator(GREATER_THAN_OR_EQUAL) + @TypeParameter(value = "T", boundedBy = VarcharEnumType.class) + @SqlType(StandardTypes.BOOLEAN) + public static boolean greaterThanOrEqual(@SqlType("T") Slice left, @SqlType("T") Slice right) + { + return left.compareTo(right) >= 0; + } + + @ScalarOperator(BETWEEN) + @TypeParameter(value = "T", boundedBy = VarcharEnumType.class) + @SqlType(StandardTypes.BOOLEAN) + public static boolean between(@SqlType("T") Slice value, @SqlType("T") Slice min, @SqlType("T") Slice max) + { + return min.compareTo(value) <= 0 && value.compareTo(max) <= 0; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/util/JsonUtil.java b/presto-main/src/main/java/com/facebook/presto/util/JsonUtil.java index 976f2cc73a916..2b4d635adcb65 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/JsonUtil.java +++ b/presto-main/src/main/java/com/facebook/presto/util/JsonUtil.java @@ -20,11 +20,13 @@ import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.DecimalType; import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.EnumType; import com.facebook.presto.common.type.MapType; import com.facebook.presto.common.type.RowType; import com.facebook.presto.common.type.RowType.Field; import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.type.BigintOperators; import com.facebook.presto.type.BooleanOperators; @@ -148,6 +150,9 @@ public static boolean canCastToJson(Type type) baseType.equals(StandardTypes.DATE)) { return true; } + if (type instanceof EnumType) { + return true; + } if (type instanceof ArrayType) { return canCastToJson(((ArrayType) type).getElementType()); } @@ -165,7 +170,11 @@ public static boolean canCastToJson(Type type) public static boolean canCastFromJson(Type type) { - String baseType = type.getTypeSignature().getBase(); + TypeSignature signature = type.getTypeSignature(); + String baseType = signature.getBase(); + if (signature.isEnum()) { + return true; + } if (baseType.equals(StandardTypes.BOOLEAN) || baseType.equals(StandardTypes.TINYINT) || baseType.equals(StandardTypes.SMALLINT) || @@ -201,7 +210,8 @@ private static boolean isValidJsonObjectKeyType(Type type) baseType.equals(StandardTypes.REAL) || baseType.equals(StandardTypes.DOUBLE) || baseType.equals(StandardTypes.DECIMAL) || - baseType.equals(StandardTypes.VARCHAR); + baseType.equals(StandardTypes.VARCHAR) || + type.getTypeSignature().isEnum(); } // transform the map key into string for use as JSON object key @@ -211,7 +221,14 @@ public interface ObjectKeyProvider static ObjectKeyProvider createObjectKeyProvider(Type type) { - String baseType = type.getTypeSignature().getBase(); + TypeSignature signature = type.getTypeSignature(); + String baseType = signature.getBase(); + if (signature.isLongEnum()) { + return (block, position) -> String.valueOf(type.getLong(block, position)); + } + if (signature.isVarcharEnum()) { + return (block, position) -> type.getSlice(block, position).toStringUtf8(); + } switch (baseType) { case UnknownType.NAME: return (block, position) -> null; @@ -253,7 +270,14 @@ void writeJsonValue(JsonGenerator jsonGenerator, Block block, int position, SqlF static JsonGeneratorWriter createJsonGeneratorWriter(Type type) { - String baseType = type.getTypeSignature().getBase(); + TypeSignature signature = type.getTypeSignature(); + String baseType = signature.getBase(); + if (signature.isLongEnum()) { + return new LongJsonGeneratorWriter(type); + } + if (signature.isVarcharEnum()) { + return new VarcharJsonGeneratorWriter(type); + } switch (baseType) { case UnknownType.NAME: return new UnknownJsonGeneratorWriter(); @@ -864,7 +888,14 @@ void append(JsonParser parser, BlockBuilder blockBuilder) static BlockBuilderAppender createBlockBuilderAppender(Type type) { - String baseType = type.getTypeSignature().getBase(); + TypeSignature signature = type.getTypeSignature(); + String baseType = signature.getBase(); + if (signature.isLongEnum()) { + return new BigintBlockBuilderAppender(); + } + if (signature.isVarcharEnum()) { + return new VarcharBlockBuilderAppender(type); + } switch (baseType) { case StandardTypes.BOOLEAN: return new BooleanBlockBuilderAppender(); diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/TestingPrestoClient.java b/presto-tests/src/main/java/com/facebook/presto/tests/TestingPrestoClient.java index 7b4029471be6c..6c9590d674934 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/TestingPrestoClient.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/TestingPrestoClient.java @@ -20,11 +20,14 @@ import com.facebook.presto.client.QueryStatusInfo; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.JsonType; +import com.facebook.presto.common.type.LongEnumType; import com.facebook.presto.common.type.MapType; import com.facebook.presto.common.type.RowType; import com.facebook.presto.common.type.SqlTimestamp; import com.facebook.presto.common.type.SqlTimestampWithTimeZone; import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharEnumType; import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.server.testing.TestingPrestoServer; import com.facebook.presto.spi.PrestoWarning; @@ -254,6 +257,15 @@ else if (type instanceof RowType) { else if (type instanceof DecimalType) { return new BigDecimal((String) value); } + else if (type instanceof JsonType) { + return value; + } + else if (type instanceof VarcharEnumType) { + return value; + } + else if (type instanceof LongEnumType) { + return ((Number) value).longValue(); + } else if (type.getTypeSignature().getBase().equals("ObjectId")) { return value; } diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestEnums.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestEnums.java new file mode 100644 index 0000000000000..177455ea22697 --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestEnums.java @@ -0,0 +1,265 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests; + +import com.facebook.presto.Session; +import com.facebook.presto.common.type.LongEnumType.LongEnumMap; +import com.facebook.presto.common.type.ParametricType; +import com.facebook.presto.common.type.VarcharEnumType.VarcharEnumMap; +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.MaterializedRow; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.type.LongEnumParametricType; +import com.facebook.presto.type.VarcharEnumParametricType; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.stream.Collectors; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.util.Collections.singletonList; +import static org.testng.Assert.assertEquals; + +@Test(singleThreaded = true) +public class TestEnums + extends AbstractTestQueryFramework +{ + private static final Long BIG_VALUE = Integer.MAX_VALUE + 10L; // 2147483657 + + private static final LongEnumParametricType MOOD_ENUM = new LongEnumParametricType("Mood", new LongEnumMap(ImmutableMap.of( + "HAPPY", 0L, + "SAD", 1L, + "MELLOW", BIG_VALUE, + "curious", -2L))); + private static final VarcharEnumParametricType COUNTRY_ENUM = new VarcharEnumParametricType("Country", new VarcharEnumMap(ImmutableMap.of( + "US", "United States", + "BAHAMAS", "The Bahamas", + "FRANCE", "France", + "CHINA", "中国", + "भारत", "India"))); + private static final VarcharEnumParametricType TEST_ENUM = new VarcharEnumParametricType("TestEnum", new VarcharEnumMap(ImmutableMap.of( + "TEST", "\"}\"", + "TEST2", "", + "TEST3", " ", + "TEST4", ")))\"\""))); + + static class TestEnumPlugin + implements Plugin + { + @Override + public Iterable getParametricTypes() + { + return ImmutableList.of(MOOD_ENUM, COUNTRY_ENUM, TEST_ENUM); + } + } + + protected TestEnums() + { + super(TestEnums::createQueryRunner); + } + + private static QueryRunner createQueryRunner() + { + try { + Session session = testSessionBuilder().build(); + QueryRunner queryRunner = DistributedQueryRunner.builder(session).setNodeCount(1).build(); + queryRunner.installPlugin(new TestEnumPlugin()); + return queryRunner; + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + private void assertQueryResultUnordered(@Language("SQL") String query, List> expectedRows) + { + MaterializedResult rows = computeActual(query); + assertEquals( + ImmutableSet.copyOf(rows.getMaterializedRows()), + expectedRows.stream().map(row -> new MaterializedRow(1, row)).collect(Collectors.toSet())); + } + + private void assertSingleValue(@Language("SQL") String expression, Object expectedResult) + { + assertQueryResultUnordered("SELECT " + expression, singletonList(singletonList(expectedResult))); + } + + @Test + public void testEnumLiterals() + { + assertQueryResultUnordered( + "SELECT Mood.HAPPY, mood.happY, \"mood\".SAD, \"mood\".\"mellow\"", + singletonList(ImmutableList.of(0L, 0L, 1L, BIG_VALUE))); + + assertQueryResultUnordered( + "SELECT Country.us, country.\"CHINA\", Country.\"भारत\"", + singletonList(ImmutableList.of("United States", "中国", "India"))); + + assertQueryResultUnordered( + "SELECT testEnum.TEST, testEnum.TEST2, testEnum.TEST3, array[testEnum.TEST4]", + singletonList(ImmutableList.of("\"}\"", "", " ", ImmutableList.of(")))\"\"")))); + + assertQueryFails("SELECT mood.hello", ".*No key 'HELLO' in enum 'Mood'"); + } + + @Test + public void testEnumCasts() + { + assertSingleValue("CAST(CAST(1 AS TINYINT) AS Mood)", 1L); + assertSingleValue("CAST('The Bahamas' AS COUNTRY)", "The Bahamas"); + assertSingleValue("CAST(row(1, 1) as row(x BIGINT, y Mood))", ImmutableList.of(1L, 1L)); + assertSingleValue("CAST(mood.MELLOW AS BIGINT)", BIG_VALUE); + assertSingleValue( + "cast(map(array[country.FRANCE], array[array[mood.HAPPY]]) as JSON)", + "{\"France\":[0]}"); + assertSingleValue( + "map_filter(MAP(ARRAY[country.FRANCE, country.US], ARRAY[mood.HAPPY, mood.SAD]), (k,v) -> CAST(v AS BIGINT) > 0)", + ImmutableMap.of("United States", 1L)); + assertSingleValue( + "cast(JSON '{\"France\": [0]}' as MAP>)", + ImmutableMap.of("France", singletonList(0L))); + assertQueryFails("select cast(7 as mood)", ".*No value '7' in enum 'Mood'"); + } + + @Test + public void testVarcharEnumComparisonOperators() + { + assertSingleValue("country.US = CAST('United States' AS country)", true); + assertSingleValue("country.FRANCE = country.BAHAMAS", false); + + assertSingleValue("country.FRANCE != country.US", true); + assertSingleValue("array[country.FRANCE, country.BAHAMAS] != array[country.US, country.BAHAMAS]", true); + + assertSingleValue("country.CHINA IN (country.US, null, country.BAHAMAS, country.China)", true); + assertSingleValue("country.BAHAMAS IN (country.US, country.FRANCE)", false); + + assertSingleValue("country.BAHAMAS < country.US", true); + assertSingleValue("country.BAHAMAS < country.BAHAMAS", false); + + assertSingleValue("country.\"भारत\" <= country.\"भारत\"", true); + assertSingleValue("country.\"भारत\" <= country.FRANCE", false); + + assertSingleValue("country.\"भारत\" >= country.FRANCE", true); + assertSingleValue("country.BAHAMAS >= country.US", false); + + assertSingleValue("country.\"भारत\" > country.FRANCE", true); + assertSingleValue("country.CHINA > country.CHINA", false); + + assertSingleValue("country.\"भारत\" between country.FRANCE and country.BAHAMAS", true); + assertSingleValue("country.US between country.FRANCE and country.\"भारत\"", false); + + assertQueryFails("select country.US = mood.HAPPY", ".* '=' cannot be applied to Country.*, Mood.*"); + assertQueryFails("select country.US IN (country.CHINA, mood.SAD)", ".* All IN list values must be the same type.*"); + assertQueryFails("select country.US IN (mood.HAPPY, mood.SAD)", ".* IN value and list items must be the same type: Country"); + assertQueryFails("select country.US > 2", ".* '>' cannot be applied to Country.*, integer"); + } + + @Test + public void testLongEnumComparisonOperators() + { + assertSingleValue("mood.HAPPY = CAST(0 AS mood)", true); + assertSingleValue("mood.HAPPY = mood.SAD", false); + + assertSingleValue("mood.SAD != mood.MELLOW", true); + assertSingleValue("array[mood.HAPPY, mood.SAD] != array[mood.SAD, mood.HAPPY]", true); + + assertSingleValue("mood.SAD IN (mood.HAPPY, null, mood.SAD)", true); + assertSingleValue("mood.HAPPY IN (mood.SAD, mood.MELLOW)", false); + + assertSingleValue("mood.CURIOUS < mood.MELLOW", true); + assertSingleValue("mood.SAD < mood.HAPPY", false); + + assertSingleValue("mood.HAPPY <= mood.HAPPY", true); + assertSingleValue("mood.HAPPY <= mood.CURIOUS", false); + + assertSingleValue("mood.MELLOW >= mood.SAD", true); + assertSingleValue("mood.HAPPY >= mood.SAD", false); + + assertSingleValue("mood.SAD > mood.HAPPY", true); + assertSingleValue("mood.HAPPY > mood.HAPPY", false); + + assertSingleValue("mood.HAPPY between mood.CURIOUS and mood.SAD ", true); + assertSingleValue("mood.MELLOW between mood.SAD and mood.HAPPY", false); + + assertQueryFails("select mood.HAPPY = 3", ".* '=' cannot be applied to Mood.*, integer"); + } + + @Test + public void testEnumHashOperators() + { + assertQueryResultUnordered( + "SELECT DISTINCT x " + + "FROM (VALUES mood.happy, mood.sad, mood.sad, mood.happy) t(x)", + ImmutableList.of( + ImmutableList.of(0L), + ImmutableList.of(1L))); + + assertQueryResultUnordered( + "SELECT DISTINCT x " + + "FROM (VALUES country.FRANCE, country.FRANCE, country.\"भारत\") t(x)", + ImmutableList.of( + ImmutableList.of("France"), + ImmutableList.of("India"))); + + assertQueryResultUnordered( + "SELECT APPROX_DISTINCT(x), APPROX_DISTINCT(y)" + + "FROM (VALUES (country.FRANCE, mood.HAPPY), " + + " (country.FRANCE, mood.SAD)," + + " (country.US, mood.HAPPY)) t(x, y)", + ImmutableList.of( + ImmutableList.of(2L, 2L))); + } + + @Test + public void testEnumAggregation() + { + assertQueryResultUnordered( + " SELECT a, ARRAY_AGG(DISTINCT b) " + + "FROM (VALUES (mood.happy, country.us), " + + " (mood.happy, country.china)," + + " (mood.happy, country.CHINA)," + + " (mood.sad, country.us)) t(a, b)" + + "GROUP BY a", + ImmutableList.of( + ImmutableList.of(0L, ImmutableList.of("United States", "中国")), + ImmutableList.of(1L, ImmutableList.of("United States")))); + } + + @Test + public void testEnumJoin() + { + assertQueryResultUnordered( + " SELECT t1.a, t2.b " + + "FROM (VALUES mood.happy, mood.sad, mood.mellow) t1(a) " + + "JOIN (VALUES (mood.sad, 'hello'), (mood.happy, 'world')) t2(a, b) " + + "ON t1.a = t2.a", + ImmutableList.of( + ImmutableList.of(1L, "hello"), + ImmutableList.of(0L, "world"))); + } + + @Test + public void testEnumWindow() + { + assertQueryResultUnordered( + " SELECT first_value(b) OVER (PARTITION BY a ORDER BY a) AS rnk " + + "FROM (VALUES (mood.happy, 1), (mood.happy, 3), (mood.sad, 5)) t(a, b)", + ImmutableList.of(singletonList(1), singletonList(1), singletonList(5))); + } +}