Skip to content

Commit

Permalink
Validate literals in expression analyzer
Browse files Browse the repository at this point in the history
As a result, the optimizer can safely assume that a `Literal` represents a
value of the given type. Also, queries with invalid literal values should be
guaranteed to fail, regardless of expression pruning.
  • Loading branch information
findepi committed Jan 27, 2022
1 parent 2f21df7 commit c8540c7
Show file tree
Hide file tree
Showing 12 changed files with 201 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import io.trino.sql.analyzer.Analysis.Range;
import io.trino.sql.analyzer.Analysis.ResolvedWindow;
import io.trino.sql.analyzer.PatternRecognitionAnalyzer.PatternRecognitionAnalysis;
import io.trino.sql.planner.LiteralInterpreter;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.tree.ArithmeticBinaryExpression;
Expand Down Expand Up @@ -1016,7 +1017,13 @@ protected Type visitDoubleLiteral(DoubleLiteral node, StackableAstVisitorContext
/**
 * Determines the type of a decimal literal by parsing its textual value with
 * {@code Decimals.parse} and records that type for the node.
 *
 * A {@code RuntimeException} raised while parsing is rethrown as a semantic
 * exception carrying {@code INVALID_LITERAL}, the offending node, the original
 * cause and the literal text.
 *
 * @param node the decimal literal being analyzed
 * @param context visitor context; not referenced in the lines shown here
 * @return the result of {@code setExpressionType} for the parsed decimal type
 */
@Override
protected Type visitDecimalLiteral(DecimalLiteral node, StackableAstVisitorContext<Context> context)
{
// NOTE(review): this unguarded declaration looks like the pre-change line kept by the diff rendering;
// the declaration plus try/catch below appears to replace it — confirm against the actual file.
DecimalParseResult parseResult = Decimals.parse(node.getValue());
DecimalParseResult parseResult;
try {
parseResult = Decimals.parse(node.getValue());
}
catch (RuntimeException e) {
// Preserve the cause and report the failure against the literal's own node
throw semanticException(INVALID_LITERAL, node, e, "'%s' is not a valid decimal literal", node.getValue());
}
return setExpressionType(node, parseResult.getType());
}

Expand Down Expand Up @@ -1045,6 +1052,12 @@ protected Type visitGenericLiteral(GenericLiteral node, StackableAstVisitorConte
throw semanticException(INVALID_LITERAL, node, "No literal form for type %s", type);
}
}
try {
LiteralInterpreter.evaluate(plannerContext, session, ImmutableMap.of(NodeRef.of(node), type), node);
}
catch (RuntimeException e) {
throw semanticException(INVALID_LITERAL, node, e, "'%s' is not a valid %s literal", node.getValue(), type.getDisplayName());
}

return setExpressionType(node, type);
}
Expand Down Expand Up @@ -1111,6 +1124,12 @@ protected Type visitIntervalLiteral(IntervalLiteral node, StackableAstVisitorCon
else {
type = INTERVAL_DAY_TIME;
}
try {
LiteralInterpreter.evaluate(plannerContext, session, ImmutableMap.of(NodeRef.of(node), type), node);
}
catch (RuntimeException e) {
throw semanticException(INVALID_LITERAL, node, e, "'%s' is not a valid interval literal", node.getValue());
}
return setExpressionType(node, type);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.StandardErrorCode.INVALID_LITERAL;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.VarcharType.VARCHAR;
Expand Down Expand Up @@ -275,7 +276,7 @@ public void testJsonArrayContainsInvalid()
@Test
public void testInvalidJsonParse()
{
assertInvalidFunction("JSON 'INVALID'", INVALID_FUNCTION_ARGUMENT);
assertInvalidFunction("JSON 'INVALID'", INVALID_LITERAL);
assertInvalidFunction("JSON_PARSE('INVALID')", INVALID_FUNCTION_ARGUMENT);
assertInvalidFunction("JSON_PARSE('\"x\": 1')", INVALID_FUNCTION_ARGUMENT);
assertInvalidFunction("JSON_PARSE('{}{')", INVALID_FUNCTION_ARGUMENT);
Expand Down
154 changes: 144 additions & 10 deletions core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -3250,19 +3250,153 @@ public void testIfInJoinClause()
/**
 * Verifies that malformed or out-of-range typed literals are rejected during
 * analysis with {@code INVALID_LITERAL}.
 *
 * For the integral types the "max value + 1" / "min value - 1" cases must use
 * the limits of the type under test, so that the first value outside the
 * representable range is the one being checked.
 */
@Test
public void testLiteral()
{
    // boolean
    assertFails("SELECT BOOLEAN '2'")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT BOOLEAN 'a'")
            .hasErrorCode(INVALID_LITERAL);

    // tinyint
    assertFails("SELECT TINYINT ''")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT TINYINT '128'") // max value + 1
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT TINYINT '-129'") // min value - 1
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT TINYINT '12.1'")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT TINYINT 'a'")
            .hasErrorCode(INVALID_LITERAL);

    // smallint (range is -32768..32767; the INTEGER limits used previously did not exercise the boundary)
    assertFails("SELECT SMALLINT ''")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT SMALLINT '32768'") // max value + 1
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT SMALLINT '-32769'") // min value - 1
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT SMALLINT '12.1'")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT SMALLINT 'a'")
            .hasErrorCode(INVALID_LITERAL);

    // integer
    assertFails("SELECT INTEGER ''")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT INTEGER '2147483648'") // max value + 1
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT INTEGER '-2147483649'") // min value - 1
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT INTEGER '12.1'")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT INTEGER 'a'")
            .hasErrorCode(INVALID_LITERAL);

    // bigint
    assertFails("SELECT BIGINT ''")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT BIGINT '9223372036854775808'") // max value + 1
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT BIGINT '-9223372036854775809'") // min value - 1
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT BIGINT '12.1'")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT BIGINT 'a'")
            .hasErrorCode(INVALID_LITERAL);

    // real
    assertFails("SELECT REAL ''")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT REAL '1.2.3'")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT REAL 'a'")
            .hasErrorCode(INVALID_LITERAL);

    // double
    assertFails("SELECT DOUBLE ''")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT DOUBLE '1.2.3'")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT DOUBLE 'a'")
            .hasErrorCode(INVALID_LITERAL);

    // decimal
    assertFails("SELECT 1234567890123456789012.34567890123456789") // 39 digits, decimal point
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT 0.123456789012345678901234567890123456789") // 39 digits after "0."
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT .123456789012345678901234567890123456789") // 39 digits after "."
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT DECIMAL ''")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT DECIMAL '123456789012345678901234567890123456789'") // 39 digits, no decimal point
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT DECIMAL '1234567890123456789012.34567890123456789'") // 39 digits, decimal point
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT DECIMAL '0.123456789012345678901234567890123456789'") // 39 digits after "0."
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT DECIMAL '.123456789012345678901234567890123456789'") // 39 digits after "."
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT DECIMAL 'a'")
            .hasErrorCode(INVALID_LITERAL);

    // date
    assertFails("SELECT DATE '20220101'")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT DATE 'a'")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT DATE 'today'")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT DATE '2022-01-01 UTC'")
            .hasErrorCode(INVALID_LITERAL);

    // time
    assertFails("SELECT TIME ''")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT TIME '12'")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT TIME '1234567'")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT TIME 'a'")
            .hasErrorCode(INVALID_LITERAL);

    // timestamp
    assertFails("SELECT TIMESTAMP ''")
            .hasErrorCode(INVALID_LITERAL);
    assertFails("SELECT TIMESTAMP '2012-10-31 01:00:00 PT'")
            .hasErrorCode(INVALID_LITERAL);
}
assertFails("SELECT TIMESTAMP 'a'")
.hasErrorCode(INVALID_LITERAL);
assertFails("SELECT TIMESTAMP 'now'")
.hasErrorCode(INVALID_LITERAL);

@Test
public void testJsonLiteral()
{
// TODO All the below should fail. Literals should be validated during analysis https://github.com/trinodb/trino/issues/10719
analyze("SELECT JSON '{}{'");
analyze("SELECT JSON '{} \"a\"'");
analyze("SELECT JSON '{}{abc'");
analyze("SELECT JSON '{}abc'");
analyze("SELECT JSON ''");
// interval
assertFails("SELECT INTERVAL 'a' DAY TO SECOND")
.hasErrorCode(INVALID_LITERAL);
assertFails("SELECT INTERVAL '12.1' DAY TO SECOND")
.hasErrorCode(INVALID_LITERAL);
assertFails("SELECT INTERVAL '12' YEAR TO DAY")
.hasErrorCode(INVALID_LITERAL);
assertFails("SELECT INTERVAL '12' SECOND TO MINUTE")
.hasErrorCode(INVALID_LITERAL);

// json
assertFails("SELECT JSON ''")
.hasErrorCode(INVALID_LITERAL);
assertFails("SELECT JSON '{}{'")
.hasErrorCode(INVALID_LITERAL);
assertFails("SELECT JSON '{} \"a\"'")
.hasErrorCode(INVALID_LITERAL);
assertFails("SELECT JSON '{}{'")
.hasErrorCode(INVALID_LITERAL);
assertFails("SELECT JSON '{} \"a\"'")
.hasErrorCode(INVALID_LITERAL);
assertFails("SELECT JSON '{}{abc'")
.hasErrorCode(INVALID_LITERAL);
assertFails("SELECT JSON '{}abc'")
.hasErrorCode(INVALID_LITERAL);
assertFails("SELECT JSON ''")
.hasErrorCode(INVALID_LITERAL);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import java.util.concurrent.TimeUnit;

import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static io.trino.spi.StandardErrorCode.INVALID_LITERAL;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.DateType.DATE;
import static io.trino.spi.type.TimeZoneKey.getTimeZoneKey;
Expand Down Expand Up @@ -261,8 +261,8 @@ public void testIsDistinctFrom()
/**
 * Checks evaluation of {@code DATE '...'} literals: a well-formed value yields
 * the expected date, and the two large-year values are rejected.
 */
public void testDateCastFromVarchar()
{
assertFunction("DATE '2013-02-02'", DATE, toDate(new DateTime(2013, 2, 2, 0, 0, 0, 0, UTC)));
// NOTE(review): each rejected value appears twice below — first expecting INVALID_CAST_ARGUMENT, then
// INVALID_LITERAL. This looks like the diff view showing the removed and the added assertion
// side by side; presumably only the INVALID_LITERAL pair is current — confirm against the file.
assertInvalidFunction("DATE '5881580-07-12'", INVALID_CAST_ARGUMENT, "Value cannot be cast to date: 5881580-07-12");
assertInvalidFunction("DATE '392251590-07-12'", INVALID_CAST_ARGUMENT, "Value cannot be cast to date: 392251590-07-12");
assertInvalidFunction("DATE '5881580-07-12'", INVALID_LITERAL, "line 1:1: '5881580-07-12' is not a valid date literal");
assertInvalidFunction("DATE '392251590-07-12'", INVALID_LITERAL, "line 1:1: '392251590-07-12' is not a valid date literal");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import org.testng.annotations.Test;

import static io.trino.spi.StandardErrorCode.DIVISION_BY_ZERO;
import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static io.trino.spi.StandardErrorCode.INVALID_LITERAL;
import static io.trino.spi.function.OperatorType.INDETERMINATE;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
Expand All @@ -38,7 +38,7 @@ public void testLiteral()
{
assertFunction("INTEGER '37'", INTEGER, 37);
assertFunction("INTEGER '17'", INTEGER, 17);
assertInvalidCast("INTEGER '" + ((long) Integer.MAX_VALUE + 1L) + "'");
assertInvalidFunction("INTEGER '" + ((long) Integer.MAX_VALUE + 1L) + "'", INVALID_LITERAL);
}

@Test
Expand All @@ -53,7 +53,7 @@ public void testUnaryMinus()
{
assertFunction("INTEGER '-37'", INTEGER, -37);
assertFunction("INTEGER '-17'", INTEGER, -17);
assertInvalidFunction("INTEGER '-" + Integer.MIN_VALUE + "'", INVALID_CAST_ARGUMENT);
assertInvalidFunction("INTEGER '-" + Integer.MIN_VALUE + "'", INVALID_LITERAL);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.testng.annotations.Test;

import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.StandardErrorCode.INVALID_LITERAL;
import static io.trino.spi.function.OperatorType.INDETERMINATE;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.VarcharType.VARCHAR;
Expand Down Expand Up @@ -97,11 +98,11 @@ private void assertLiteral(String projection, Type expectedType, SqlIntervalDayT
/**
 * Checks that malformed {@code INTERVAL ... DAY} and {@code DAY TO HOUR}
 * literals are rejected with the expected error message.
 */
@Test
public void testInvalidLiteral()
{
// NOTE(review): the first five assertions (message only, "Invalid INTERVAL ... value") and the last five
// (INVALID_LITERAL, "is not a valid interval literal") cover the same inputs. This looks like the diff view
// showing removed and added lines together; presumably only the INVALID_LITERAL group is current — confirm.
assertInvalidFunction("INTERVAL '12X' DAY", "Invalid INTERVAL DAY value: 12X");
assertInvalidFunction("INTERVAL '12 10' DAY", "Invalid INTERVAL DAY value: 12 10");
assertInvalidFunction("INTERVAL '12 X' DAY TO HOUR", "Invalid INTERVAL DAY TO HOUR value: 12 X");
assertInvalidFunction("INTERVAL '12 -10' DAY TO HOUR", "Invalid INTERVAL DAY TO HOUR value: 12 -10");
assertInvalidFunction("INTERVAL '--12 -10' DAY TO HOUR", "Invalid INTERVAL DAY TO HOUR value: --12 -10");
assertInvalidFunction("INTERVAL '12X' DAY", INVALID_LITERAL, "line 1:1: '12X' is not a valid interval literal");
assertInvalidFunction("INTERVAL '12 10' DAY", INVALID_LITERAL, "line 1:1: '12 10' is not a valid interval literal");
assertInvalidFunction("INTERVAL '12 X' DAY TO HOUR", INVALID_LITERAL, "line 1:1: '12 X' is not a valid interval literal");
assertInvalidFunction("INTERVAL '12 -10' DAY TO HOUR", INVALID_LITERAL, "line 1:1: '12 -10' is not a valid interval literal");
assertInvalidFunction("INTERVAL '--12 -10' DAY TO HOUR", INVALID_LITERAL, "line 1:1: '--12 -10' is not a valid interval literal");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.testng.annotations.Test;

import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.StandardErrorCode.INVALID_LITERAL;
import static io.trino.spi.function.OperatorType.INDETERMINATE;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.VarcharType.VARCHAR;
Expand Down Expand Up @@ -70,11 +71,11 @@ private void assertLiteral(String projection, Type expectedType, SqlIntervalYear
/**
 * Checks that malformed {@code INTERVAL ... YEAR} and {@code YEAR TO MONTH}
 * literals are rejected with the expected error message.
 */
@Test
public void testInvalidLiteral()
{
// NOTE(review): the first five assertions (message only, "Invalid INTERVAL ... value") and the last five
// (INVALID_LITERAL, "is not a valid interval literal") cover the same inputs. This looks like the diff view
// showing removed and added lines together; presumably only the INVALID_LITERAL group is current — confirm.
assertInvalidFunction("INTERVAL '124X' YEAR", "Invalid INTERVAL YEAR value: 124X");
assertInvalidFunction("INTERVAL '124-30' YEAR", "Invalid INTERVAL YEAR value: 124-30");
assertInvalidFunction("INTERVAL '124-X' YEAR TO MONTH", "Invalid INTERVAL YEAR TO MONTH value: 124-X");
assertInvalidFunction("INTERVAL '124--30' YEAR TO MONTH", "Invalid INTERVAL YEAR TO MONTH value: 124--30");
assertInvalidFunction("INTERVAL '--124--30' YEAR TO MONTH", "Invalid INTERVAL YEAR TO MONTH value: --124--30");
assertInvalidFunction("INTERVAL '124X' YEAR", INVALID_LITERAL, "line 1:1: '124X' is not a valid interval literal");
assertInvalidFunction("INTERVAL '124-30' YEAR", INVALID_LITERAL, "line 1:1: '124-30' is not a valid interval literal");
assertInvalidFunction("INTERVAL '124-X' YEAR TO MONTH", INVALID_LITERAL, "line 1:1: '124-X' is not a valid interval literal");
assertInvalidFunction("INTERVAL '124--30' YEAR TO MONTH", INVALID_LITERAL, "line 1:1: '124--30' is not a valid interval literal");
assertInvalidFunction("INTERVAL '--124--30' YEAR TO MONTH", INVALID_LITERAL, "line 1:1: '--124--30' is not a valid interval literal");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import static io.trino.SessionTestUtils.TEST_SESSION;
import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.StandardErrorCode.INVALID_LITERAL;
import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH;
import static io.trino.spi.function.OperatorType.INDETERMINATE;
import static io.trino.spi.type.BigintType.BIGINT;
Expand Down Expand Up @@ -180,11 +181,11 @@ public void testTypeConstructor()
assertFunction("JSON '[null]'", JSON, "[null]");
assertFunction("JSON '[13,null,42]'", JSON, "[13,null,42]");
assertFunction("JSON '{\"x\": null}'", JSON, "{\"x\":null}");
assertInvalidFunction("JSON '{}{'", INVALID_FUNCTION_ARGUMENT);
assertInvalidFunction("JSON '{} \"a\"'", INVALID_FUNCTION_ARGUMENT);
assertInvalidFunction("JSON '{}{abc'", INVALID_FUNCTION_ARGUMENT);
assertInvalidFunction("JSON '{}abc'", INVALID_FUNCTION_ARGUMENT);
assertInvalidFunction("JSON ''", INVALID_FUNCTION_ARGUMENT);
assertInvalidFunction("JSON '{}{'", INVALID_LITERAL);
assertInvalidFunction("JSON '{} \"a\"'", INVALID_LITERAL);
assertInvalidFunction("JSON '{}{abc'", INVALID_LITERAL);
assertInvalidFunction("JSON '{}abc'", INVALID_LITERAL);
assertInvalidFunction("JSON ''", INVALID_LITERAL);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import org.testng.annotations.Test;

import static io.trino.spi.StandardErrorCode.DIVISION_BY_ZERO;
import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static io.trino.spi.StandardErrorCode.INVALID_LITERAL;
import static io.trino.spi.function.OperatorType.INDETERMINATE;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
Expand All @@ -38,7 +38,7 @@ public void testLiteral()
{
assertFunction("SMALLINT '37'", SMALLINT, (short) 37);
assertFunction("SMALLINT '17'", SMALLINT, (short) 17);
assertInvalidCast("SMALLINT '" + ((long) Short.MAX_VALUE + 1L) + "'");
assertInvalidFunction("SMALLINT '" + ((long) Short.MAX_VALUE + 1L) + "'", INVALID_LITERAL);
}

@Test
Expand All @@ -53,7 +53,7 @@ public void testUnaryMinus()
{
assertFunction("SMALLINT '-37'", SMALLINT, (short) -37);
assertFunction("SMALLINT '-17'", SMALLINT, (short) -17);
assertInvalidFunction("SMALLINT '-" + Short.MIN_VALUE + "'", INVALID_CAST_ARGUMENT);
assertInvalidFunction("SMALLINT '-" + Short.MIN_VALUE + "'", INVALID_LITERAL);
}

@Test
Expand Down
Loading

0 comments on commit c8540c7

Please sign in to comment.