From b4ce89587ed7c704be389a9aa16b0e02fde59ff9 Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Tue, 8 Nov 2022 10:24:00 -0800 Subject: [PATCH] Fix ClassCastException in DomainTranslator When theres's an expression such as: CAST('2022-01-01') AS date) BETWEEN CAST(start_date AS date) AND CAST(end_date AS date) There's a call to visitComparisonExpression with the term: DATE '2022-01-01' >= CAST(start_date AS date) Inside that method, the expression is normalized to have the symbol on the left and the constant on the right. However, the createVarcharCastToDateComparisonExtractionResult pulls the elements from the unnormalized ComparisonExpression node and expects the left subexpression to be cast, which results in a failure due to ClassCastException --- .../trino/sql/planner/DomainTranslator.java | 25 ++++++++-------- .../sql/planner/TestDomainTranslator.java | 29 +++++++++++++++++++ 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java index fbf5ff056f06..f03ecff9785e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java @@ -512,10 +512,10 @@ protected ExtractionResult visitComparisonExpression(ComparisonExpression node, Type castTargetType = requireNonNull(expressionTypes.get(NodeRef.of(castExpression)), "No type for Cast target expression"); if (castSourceType instanceof VarcharType && castTargetType == DATE && !castExpression.isSafe()) { Optional result = createVarcharCastToDateComparisonExtractionResult( - node, + normalized, (VarcharType) castSourceType, - normalized.getValue(), - complement); + complement, + node); if (result.isPresent()) { return result.get(); } @@ -604,15 +604,14 @@ private Map, Type> analyzeExpression(Expression expression) } private Optional createVarcharCastToDateComparisonExtractionResult( - ComparisonExpression node, + NormalizedSimpleComparison comparison, VarcharType sourceType, - NullableValue value, - boolean complement) + boolean complement, + ComparisonExpression originalExpression) { - Cast castExpression = (Cast) node.getLeft(); - Expression sourceExpression = castExpression.getExpression(); - ComparisonExpression.Operator comparisonOperator = node.getOperator(); - requireNonNull(value, "value is null"); + Expression sourceExpression = ((Cast) comparison.getSymbolExpression()).getExpression(); + ComparisonExpression.Operator operator = comparison.getComparisonOperator(); + NullableValue value = comparison.getValue(); if (complement || value.isNull()) { return Optional.empty(); @@ -638,7 +637,7 @@ private Optional createVarcharCastToDateComparisonExtractionRe ValueSet valueSet; boolean nullAllowed = false; - switch (comparisonOperator) { + switch (operator) { case EQUAL: valueSet = dateStringRanges(date, sourceType); break; @@ -649,7 +648,7 @@ private Optional createVarcharCastToDateComparisonExtractionRe return Optional.empty(); } valueSet = ValueSet.all(sourceType).subtract(dateStringRanges(date, sourceType)); - nullAllowed = (comparisonOperator == IS_DISTINCT_FROM); + nullAllowed = (operator == IS_DISTINCT_FROM); break; case LESS_THAN: case LESS_THAN_OR_EQUAL: @@ -670,7 +669,7 @@ private Optional createVarcharCastToDateComparisonExtractionRe return Optional.of(new ExtractionResult( TupleDomain.withColumnDomains(ImmutableMap.of(sourceSymbol, Domain.create(valueSet, nullAllowed))), - node)); + originalExpression)); } /** diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestDomainTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestDomainTranslator.java index f4963d1c5579..c6b2b9355e56 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestDomainTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestDomainTranslator.java @@ -1072,6 +1072,17 @@ public void testPredicateWithVarcharCastToDate() Range.greaterThan(VARCHAR, utf8Slice("2004"))), false))); + // Regression test for https://github.com/trinodb/trino/issues/14954 + assertPredicateTranslates( + greaterThan(new GenericLiteral("DATE", "2001-01-31"), cast(C_VARCHAR, DATE)), + tupleDomain( + C_VARCHAR, + Domain.create(ValueSet.ofRanges( + Range.lessThan(VARCHAR, utf8Slice("2002")), + Range.greaterThan(VARCHAR, utf8Slice("9"))), + false)), + greaterThan(new GenericLiteral("DATE", "2001-01-31"), cast(C_VARCHAR, DATE))); + // BETWEEN assertPredicateTranslates( between(cast(C_VARCHAR, DATE), new GenericLiteral("DATE", "2001-01-31"), new GenericLiteral("DATE", "2005-09-10")), @@ -1083,6 +1094,24 @@ public void testPredicateWithVarcharCastToDate() and( greaterThanOrEqual(cast(C_VARCHAR, DATE), new GenericLiteral("DATE", "2001-01-31")), lessThanOrEqual(cast(C_VARCHAR, DATE), new GenericLiteral("DATE", "2005-09-10")))); + + // Regression test for https://github.com/trinodb/trino/issues/14954 + assertPredicateTranslates( + between(new GenericLiteral("DATE", "2001-01-31"), cast(C_VARCHAR, DATE), cast(C_VARCHAR_1, DATE)), + tupleDomain( + C_VARCHAR, + Domain.create(ValueSet.ofRanges( + Range.lessThan(VARCHAR, utf8Slice("2002")), + Range.greaterThan(VARCHAR, utf8Slice("9"))), + false), + C_VARCHAR_1, + Domain.create(ValueSet.ofRanges( + Range.lessThan(VARCHAR, utf8Slice("1")), + Range.greaterThan(VARCHAR, utf8Slice("2000"))), + false)), + and( + greaterThanOrEqual(new GenericLiteral("DATE", "2001-01-31"), cast(C_VARCHAR, DATE)), + lessThanOrEqual(new GenericLiteral("DATE", "2001-01-31"), cast(C_VARCHAR_1, DATE)))); } @Test