diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/QueryBuilder.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/QueryBuilder.java index 2964e80cd801..5c4d82598c4f 100644 --- a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/QueryBuilder.java +++ b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/QueryBuilder.java @@ -23,6 +23,7 @@ import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.predicate.Range; import io.prestosql.spi.predicate.TupleDomain; +import io.prestosql.spi.predicate.ValueSet; import io.prestosql.spi.type.Type; import java.sql.Connection; @@ -34,6 +35,7 @@ import java.util.Optional; import java.util.function.Function; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.Iterables.getOnlyElement; @@ -199,9 +201,27 @@ private String toPredicate(JdbcColumnHandle column, Domain domain, List accumulator) + { + checkArgument(!valueSet.isNone(), "none values should be handled earlier"); + + if (!valueSet.isDiscreteSet()) { + ValueSet complement = valueSet.complement(); + if (complement.isDiscreteSet()) { + return format("NOT (%s)", toPredicate(column, complement, accumulator)); + } + } + List disjuncts = new ArrayList<>(); List singleValues = new ArrayList<>(); - for (Range range : domain.getValues().getRanges().getOrderedRanges()) { + for (Range range : valueSet.getRanges().getOrderedRanges()) { checkState(!range.isAll()); // Already checked if (range.isSingleValue()) { singleValues.add(range.getLow().getValue()); @@ -238,7 +258,12 @@ private String toPredicate(JdbcColumnHandle column, Domain domain, List 1) { disjuncts.add(client.quoted(column.getColumnName()) + " IN (" + values + ")"); } - // Add nullability disjuncts checkState(!disjuncts.isEmpty()); - if (domain.isNullAllowed()) { - disjuncts.add(client.quoted(column.getColumnName()) + " IS NULL"); + if (disjuncts.size() == 1) { + return getOnlyElement(disjuncts); } - return "(" + Joiner.on(" OR ").join(disjuncts) + ")"; } diff --git a/presto-base-jdbc/src/test/java/io/prestosql/plugin/jdbc/TestJdbcQueryBuilder.java b/presto-base-jdbc/src/test/java/io/prestosql/plugin/jdbc/TestJdbcQueryBuilder.java index 588136803ce6..cfd8ab19f136 100644 --- a/presto-base-jdbc/src/test/java/io/prestosql/plugin/jdbc/TestJdbcQueryBuilder.java +++ b/presto-base-jdbc/src/test/java/io/prestosql/plugin/jdbc/TestJdbcQueryBuilder.java @@ -47,7 +47,9 @@ import java.util.Locale; import java.util.Optional; import java.util.function.Function; +import java.util.stream.LongStream; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.slice.Slices.utf8Slice; import static io.airlift.testing.Assertions.assertContains; import static io.prestosql.plugin.jdbc.TestingJdbcTypeHandle.JDBC_BIGINT; @@ -75,7 +77,6 @@ import static io.prestosql.spi.type.VarcharType.VARCHAR; import static io.prestosql.testing.DateTimeTestingUtils.sqlTimeOf; import static io.prestosql.testing.DateTimeTestingUtils.sqlTimestampOf; -import static io.prestosql.testing.TestingConnectorSession.SESSION; import static io.prestosql.type.DateTimes.MICROSECONDS_PER_MILLISECOND; import static java.lang.Float.floatToRawIntBits; import static java.lang.String.format; @@ -237,16 +238,71 @@ public void testNormalBuildSql() .build()); Connection connection = database.getConnection(); - try (PreparedStatement preparedStatement = new QueryBuilder(jdbcClient).buildSql(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty(), identity()); - ResultSet resultSet = preparedStatement.executeQuery()) { + try (PreparedStatement preparedStatement = new QueryBuilder(jdbcClient).buildSql(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty(), identity())) { + assertThat(lastQuery).isEqualTo("" + + "SELECT \"col_0\" AS \"col_0\", \"col_1\" AS \"col_1\", \"col_2\" AS \"col_2\", \"col_3\" AS \"col_3\", \"col_4\" AS \"col_4\", \"col_5\" AS \"col_5\", " + + "\"col_6\" AS \"col_6\", \"col_7\" AS \"col_7\", \"col_8\" AS \"col_8\", \"col_9\" AS \"col_9\", \"col_10\" AS \"col_10\", \"col_11\" AS \"col_11\" " + + "FROM \"test_table\" " + + "WHERE (\"col_0\" < ? OR (\"col_0\" >= ? AND \"col_0\" <= ?) OR \"col_0\" > ? OR \"col_0\" IN (?,?)) " + + "AND ((\"col_1\" >= ? AND \"col_1\" <= ?) OR (\"col_1\" >= ? AND \"col_1\" <= ?) OR \"col_1\" IN (?,?,?,?)) " + + "AND ((\"col_7\" >= ? AND \"col_7\" < ?) OR (\"col_7\" >= ? AND \"col_7\" < ?)) " + + "AND ((\"col_8\" >= ? AND \"col_8\" < ?) OR (\"col_8\" >= ? AND \"col_8\" <= ?)) " + + "AND (\"col_9\" < ? OR \"col_9\" IN (?,?)) " + + "AND \"col_2\" = ?"); ImmutableSet.Builder builder = ImmutableSet.builder(); - while (resultSet.next()) { - builder.add((Long) resultSet.getObject("col_0")); + try (ResultSet resultSet = preparedStatement.executeQuery()) { + while (resultSet.next()) { + builder.add((Long) resultSet.getObject("col_0")); + } } assertEquals(builder.build(), ImmutableSet.of(68L, 180L, 196L)); } } + /** + * Test query generation for domains originating from {@code NOT IN} predicates. + * + * @see Domain#complement() + */ + @Test + public void testBuildSqlWithDomainComplement() + throws SQLException + { + TupleDomain tupleDomain = TupleDomain.withColumnDomains(ImmutableMap.builder() + // complement of a Domain with null not allowed + .put(columns.get(0), Domain.create(ValueSet.of(BIGINT, 128L, 180L, 233L), false).complement()) + // complement of a Domain with null allowed + .put(columns.get(1), Domain.create(ValueSet.of(DOUBLE, 200011.0, 200014.0, 200017.0), true).complement()) + // this is here only to limit the list of results being read + .put(columns.get(9), Domain.create(ValueSet.ofRanges(Range.greaterThanOrEqual(INTEGER, 880L)), false)) + .build()); + + Connection connection = database.getConnection(); + try (PreparedStatement preparedStatement = new QueryBuilder(jdbcClient).buildSql( + SESSION, + connection, + TEST_TABLE, + Optional.empty(), + List.of(columns.get(0), columns.get(3), columns.get(9)), + tupleDomain, + Optional.empty(), + identity())) { + assertThat(lastQuery).isEqualTo("" + + "SELECT \"col_0\" AS \"col_0\", \"col_3\" AS \"col_3\", \"col_9\" AS \"col_9\" " + + "FROM \"test_table\" " + + "WHERE (NOT (\"col_0\" IN (?,?,?)) OR \"col_0\" IS NULL) " + + "AND NOT (\"col_1\" IN (?,?,?)) " + + "AND \"col_9\" >= ?"); + ImmutableSet.Builder builder = ImmutableSet.builder(); + try (ResultSet resultSet = preparedStatement.executeQuery()) { + while (resultSet.next()) { + builder.add((Long) resultSet.getObject("col_0")); + } + } + assertEquals(builder.build(), LongStream.range(980, 1000).boxed().collect(toImmutableList())); + } + } + @Test public void testBuildSqlWithFloat() throws SQLException @@ -260,13 +316,19 @@ public void testBuildSqlWithFloat() false))); Connection connection = database.getConnection(); - try (PreparedStatement preparedStatement = new QueryBuilder(jdbcClient).buildSql(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty(), identity()); - ResultSet resultSet = preparedStatement.executeQuery()) { + try (PreparedStatement preparedStatement = new QueryBuilder(jdbcClient).buildSql(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty(), identity())) { + assertThat(lastQuery).isEqualTo("" + + "SELECT \"col_0\" AS \"col_0\", \"col_1\" AS \"col_1\", \"col_2\" AS \"col_2\", \"col_3\" AS \"col_3\", \"col_4\" AS \"col_4\", \"col_5\" AS \"col_5\", " + + "\"col_6\" AS \"col_6\", \"col_7\" AS \"col_7\", \"col_8\" AS \"col_8\", \"col_9\" AS \"col_9\", \"col_10\" AS \"col_10\", \"col_11\" AS \"col_11\" " + + "FROM \"test_table\" " + + "WHERE \"col_10\" IN (?,?,?)"); ImmutableSet.Builder longBuilder = ImmutableSet.builder(); ImmutableSet.Builder floatBuilder = ImmutableSet.builder(); - while (resultSet.next()) { - longBuilder.add((Long) resultSet.getObject("col_0")); - floatBuilder.add((Float) resultSet.getObject("col_10")); + try (ResultSet resultSet = preparedStatement.executeQuery()) { + while (resultSet.next()) { + longBuilder.add((Long) resultSet.getObject("col_0")); + floatBuilder.add((Float) resultSet.getObject("col_10")); + } } assertEquals(longBuilder.build(), ImmutableSet.of(0L, 14L)); assertEquals(floatBuilder.build(), ImmutableSet.of(100.0f, 114.0f)); @@ -286,11 +348,17 @@ public void testBuildSqlWithVarchar() false))); Connection connection = database.getConnection(); - try (PreparedStatement preparedStatement = new QueryBuilder(jdbcClient).buildSql(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty(), identity()); - ResultSet resultSet = preparedStatement.executeQuery()) { + try (PreparedStatement preparedStatement = new QueryBuilder(jdbcClient).buildSql(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty(), identity())) { + assertThat(lastQuery).isEqualTo("" + + "SELECT \"col_0\" AS \"col_0\", \"col_1\" AS \"col_1\", \"col_2\" AS \"col_2\", \"col_3\" AS \"col_3\", \"col_4\" AS \"col_4\", \"col_5\" AS \"col_5\", " + + "\"col_6\" AS \"col_6\", \"col_7\" AS \"col_7\", \"col_8\" AS \"col_8\", \"col_9\" AS \"col_9\", \"col_10\" AS \"col_10\", \"col_11\" AS \"col_11\" " + + "FROM \"test_table\" " + + "WHERE ((\"col_3\" >= ? AND \"col_3\" < ?) OR \"col_3\" IN (?,?))"); ImmutableSet.Builder builder = ImmutableSet.builder(); - while (resultSet.next()) { - builder.add((String) resultSet.getObject("col_3")); + try (ResultSet resultSet = preparedStatement.executeQuery()) { + while (resultSet.next()) { + builder.add((String) resultSet.getObject("col_3")); + } } assertEquals(builder.build(), ImmutableSet.of("test_str_700", "test_str_701", "test_str_180", "test_str_196")); @@ -314,11 +382,17 @@ public void testBuildSqlWithChar() false))); Connection connection = database.getConnection(); - try (PreparedStatement preparedStatement = new QueryBuilder(jdbcClient).buildSql(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty(), identity()); - ResultSet resultSet = preparedStatement.executeQuery()) { + try (PreparedStatement preparedStatement = new QueryBuilder(jdbcClient).buildSql(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty(), identity())) { + assertThat(lastQuery).isEqualTo("" + + "SELECT \"col_0\" AS \"col_0\", \"col_1\" AS \"col_1\", \"col_2\" AS \"col_2\", \"col_3\" AS \"col_3\", \"col_4\" AS \"col_4\", \"col_5\" AS \"col_5\", " + + "\"col_6\" AS \"col_6\", \"col_7\" AS \"col_7\", \"col_8\" AS \"col_8\", \"col_9\" AS \"col_9\", \"col_10\" AS \"col_10\", \"col_11\" AS \"col_11\" " + + "FROM \"test_table\" " + + "WHERE ((\"col_11\" >= ? AND \"col_11\" < ?) OR \"col_11\" IN (?,?))"); ImmutableSet.Builder builder = ImmutableSet.builder(); - while (resultSet.next()) { - builder.add((String) resultSet.getObject("col_11")); + try (ResultSet resultSet = preparedStatement.executeQuery()) { + while (resultSet.next()) { + builder.add((String) resultSet.getObject("col_11")); + } } assertEquals(builder.build(), ImmutableSet.of("test_str_700", "test_str_701", "test_str_180", "test_str_196")); @@ -347,13 +421,19 @@ public void testBuildSqlWithDateTime() false))); Connection connection = database.getConnection(); - try (PreparedStatement preparedStatement = new QueryBuilder(jdbcClient).buildSql(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty(), identity()); - ResultSet resultSet = preparedStatement.executeQuery()) { + try (PreparedStatement preparedStatement = new QueryBuilder(jdbcClient).buildSql(SESSION, connection, TEST_TABLE, Optional.empty(), columns, tupleDomain, Optional.empty(), identity())) { + assertThat(lastQuery).isEqualTo("" + + "SELECT \"col_0\" AS \"col_0\", \"col_1\" AS \"col_1\", \"col_2\" AS \"col_2\", \"col_3\" AS \"col_3\", \"col_4\" AS \"col_4\", \"col_5\" AS \"col_5\", " + + "\"col_6\" AS \"col_6\", \"col_7\" AS \"col_7\", \"col_8\" AS \"col_8\", \"col_9\" AS \"col_9\", \"col_10\" AS \"col_10\", \"col_11\" AS \"col_11\" " + + "FROM \"test_table\" " + + "WHERE ((\"col_4\" >= ? AND \"col_4\" < ?) OR \"col_4\" IN (?,?)) AND ((\"col_5\" > ? AND \"col_5\" <= ?) OR \"col_5\" IN (?,?))"); ImmutableSet.Builder dateBuilder = ImmutableSet.builder(); ImmutableSet.Builder