Skip to content

Commit

Permalink
[multistage] bridge v2 query engine for leaf stage v1 multi-value col…
Browse files Browse the repository at this point in the history
…umn (#11117)

* [multistage] bridge v2 query engine for leaf stage v1 group by multi-value column

* use multi-set

* Change multi-value type back to array

* rewrite arrayToMV at leaf stage

* Enable more tests

* fix integration tests with generated queries

* Address comments

* Take out MultiValueBetweenPredicateGenerator from _multistageSingleValuePredicateGenerators
  • Loading branch information
xiangfu0 authored Jul 26, 2023
1 parent a86ba9c commit a0ff2e8
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -216,5 +216,11 @@ public static Object clpDecode(String logtypeFieldName, String dictVarsFieldName
String defaultValue) {
throw new UnsupportedOperationException("Placeholder scalar function, should not reach here");
}

@ScalarFunction(names = {"arrayToMV", "array_to_mv"},
isPlaceholder = true)
public static String arrayToMV(Object multiValue) {
throw new UnsupportedOperationException("Placeholder scalar function, should not reach here");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ public enum TransformFunctionType {
// date type conversion functions
CAST("cast"),

// object type
ARRAY_TO_MV("arrayToMV",
ReturnTypes.cascade(opBinding -> positionalComponentReturnType(opBinding, 0), SqlTypeTransforms.FORCE_NULLABLE),
OperandTypes.family(SqlTypeFamily.ARRAY), "array_to_mv"),

// string functions
JSONEXTRACTSCALAR("jsonExtractScalar",
ReturnTypes.cascade(opBinding -> positionalReturnTypeInferenceFromStringLiteral(opBinding, 2,
Expand Down Expand Up @@ -280,6 +285,13 @@ private static RelDataType positionalReturnTypeInferenceFromStringLiteral(SqlOpe
return opBinding.getTypeFactory().createSqlType(defaultSqlType);
}

private static RelDataType positionalComponentReturnType(SqlOperatorBinding opBinding, int pos) {
if (opBinding.getOperandCount() > pos) {
return opBinding.getOperandType(pos).getComponentType();
}
throw new IllegalArgumentException("Invalid number of arguments for function " + opBinding.getOperator().getName());
}

private static RelDataType dateTimeConverterReturnTypeInference(SqlOperatorBinding opBinding) {
int outputFormatPos = 2;
if (opBinding.getOperandCount() > outputFormatPos
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ public static TransformFunction get(ExpressionContext expression, Map<String, Co
return new IdentifierTransformFunction(columnName, columnContextMap.get(columnName));
case LITERAL:
return queryContext.getOrComputeSharedValue(LiteralTransformFunction.class, expression.getLiteral(),
LiteralTransformFunction::new);
LiteralTransformFunction::new);
default:
throw new IllegalStateException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ public class QueryGenerator {
private final List<PredicateGenerator> _multistageSingleValuePredicateGenerators =
Arrays.asList(new SingleValueComparisonPredicateGenerator(), new SingleValueInPredicateGenerator(),
new SingleValueBetweenPredicateGenerator());
// TODO: add MultiValueBetweenPredicateGenerator back once the BETWEEEN AND operator is supported in multistage engine
private final List<PredicateGenerator> _multiValuePredicateGenerators =
Arrays.asList(new MultiValueComparisonPredicateGenerator(), new MultiValueInPredicateGenerator(),
new MultiValueBetweenPredicateGenerator());
Arrays.asList(new MultiValueComparisonPredicateGenerator(), new MultiValueInPredicateGenerator());

private final String _pinotTableName;
private final String _h2TableName;
Expand Down Expand Up @@ -351,10 +351,12 @@ private PredicateQueryFragment generatePredicate() {
if (!_columnToValueList.get(columnName).isEmpty()) {
if (!_multiValueColumnMaxNumElements.containsKey(columnName)) {
// Single-value column.
predicates.add(pickRandom(getSingleValuePredicateGenerators()).generatePredicate(columnName));
predicates.add(pickRandom(getSingleValuePredicateGenerators()).generatePredicate(columnName,
_useMultistageEngine));
} else if (!_skipMultiValuePredicates) {
// Multi-value column.
predicates.add(pickRandom(_multiValuePredicateGenerators).generatePredicate(columnName));
predicates.add(
pickRandom(_multiValuePredicateGenerators).generatePredicate(columnName, _useMultistageEngine));
}
}
}
Expand Down Expand Up @@ -407,10 +409,11 @@ private interface PredicateGenerator {
/**
* Generate a predicate query fragment on a column.
*
* @param columnName column name.
* @param columnName column name.
* @param useMultistageEngine
* @return generated predicate query fragment.
*/
QueryFragment generatePredicate(String columnName);
QueryFragment generatePredicate(String columnName, boolean useMultistageEngine);
}

/**
Expand Down Expand Up @@ -485,16 +488,47 @@ public AggregationQuery(List<String> aggregateColumnsAndFunctions, PredicateQuer

@Override
public String generatePinotQuery() {
List<String> pinotAggregateColumnAndFunctions =
(_useMultistageEngine && !_skipMultiValuePredicates) ? generatePinotMultistageQuery()
: _aggregateColumnsAndFunctions;
if (_groupColumns.isEmpty()) {
return joinWithSpaces("SELECT", StringUtils.join(_aggregateColumnsAndFunctions, ", "), "FROM", _pinotTableName,
_predicate.generatePinotQuery());
return joinWithSpaces("SELECT", StringUtils.join(pinotAggregateColumnAndFunctions, ", "), "FROM",
_pinotTableName, _predicate.generatePinotQuery());
} else {
return joinWithSpaces("SELECT", StringUtils.join(_aggregateColumnsAndFunctions, ", "), "FROM", _pinotTableName,
_predicate.generatePinotQuery(), "GROUP BY", StringUtils.join(_groupColumns, ", "),
return joinWithSpaces("SELECT", StringUtils.join(pinotAggregateColumnAndFunctions, ", "), "FROM",
_pinotTableName, _predicate.generatePinotQuery(), "GROUP BY", StringUtils.join(_groupColumns, ", "),
_havingPredicate.generatePinotQuery(), _limit.generatePinotQuery());
}
}

public List<String> generatePinotMultistageQuery() {
List<String> pinotAggregateColumnAndFunctions = new ArrayList<>();
for (String aggregateColumnAndFunction : _aggregateColumnsAndFunctions) {
String pinotAggregateFunction = aggregateColumnAndFunction;
String pinotAggregateColumnAndFunction = pinotAggregateFunction;
if (!pinotAggregateFunction.equals("COUNT(*)")) {
pinotAggregateFunction = pinotAggregateFunction.replace("(", "(`").replace(")", "`)");
}
if (!pinotAggregateFunction.contains("(")) {
pinotAggregateFunction = String.format("`%s`", pinotAggregateFunction);
}
if (AGGREGATION_FUNCTIONS.contains(pinotAggregateFunction.substring(0, 3))) {
// For multistage query, we need to explicit hoist the data type to avoid overflow.
String aggFunctionName = pinotAggregateFunction.substring(0, 3);
String replacedPinotAggregationFunction =
pinotAggregateFunction.replace(aggFunctionName + "(", aggFunctionName + "(CAST(");
if ("SUM".equalsIgnoreCase(aggFunctionName)) {
pinotAggregateColumnAndFunction = replacedPinotAggregationFunction.replace(")", " AS BIGINT))");
}
if ("AVG".equalsIgnoreCase(aggFunctionName)) {
pinotAggregateColumnAndFunction = replacedPinotAggregationFunction.replace(")", " AS DOUBLE))");
}
}
pinotAggregateColumnAndFunctions.add(pinotAggregateColumnAndFunction);
}
return pinotAggregateColumnAndFunctions;
}

@Override
public String generateH2Query() {
List<String> h2AggregateColumnAndFunctions = new ArrayList<>();
Expand Down Expand Up @@ -923,7 +957,7 @@ private String generateRandomValue(boolean generateInt) {
private class SingleValueComparisonPredicateGenerator implements PredicateGenerator {

@Override
public QueryFragment generatePredicate(String columnName) {
public QueryFragment generatePredicate(String columnName, boolean useMultistageEngine) {
String columnValue = pickRandom(_columnToValueList.get(columnName));
String comparisonOperator = pickRandom(COMPARISON_OPERATORS);
return new StringQueryFragment(joinWithSpaces(columnName, comparisonOperator, columnValue),
Expand All @@ -937,7 +971,7 @@ public QueryFragment generatePredicate(String columnName) {
private class SingleValueInPredicateGenerator implements PredicateGenerator {

@Override
public QueryFragment generatePredicate(String columnName) {
public QueryFragment generatePredicate(String columnName, boolean useMultistageEngine) {
List<String> columnValues = _columnToValueList.get(columnName);

int numValues = Math.min(RANDOM.nextInt(MAX_NUM_IN_CLAUSE_VALUES) + 1, columnValues.size());
Expand All @@ -964,7 +998,7 @@ public QueryFragment generatePredicate(String columnName) {
private class SingleValueBetweenPredicateGenerator implements PredicateGenerator {

@Override
public QueryFragment generatePredicate(String columnName) {
public QueryFragment generatePredicate(String columnName, boolean useMultistageEngine) {
List<String> columnValues = _columnToValueList.get(columnName);
String leftValue = pickRandom(columnValues);
String rightValue = pickRandom(columnValues);
Expand All @@ -981,7 +1015,7 @@ private class SingleValueRegexPredicateGenerator implements PredicateGenerator {
Random _random = new Random();

@Override
public QueryFragment generatePredicate(String columnName) {
public QueryFragment generatePredicate(String columnName, boolean useMultistageEngine) {
List<String> columnValues = _columnToValueList.get(columnName);
String value = pickRandom(columnValues);
// do regex only for string type
Expand All @@ -1008,7 +1042,7 @@ public QueryFragment generatePredicate(String columnName) {
private class MultiValueComparisonPredicateGenerator implements PredicateGenerator {

@Override
public QueryFragment generatePredicate(String columnName) {
public QueryFragment generatePredicate(String columnName, boolean useMultistageEngine) {
String columnValue = pickRandom(_columnToValueList.get(columnName));
String comparisonOperator = pickRandom(COMPARISON_OPERATORS);

Expand All @@ -1024,7 +1058,8 @@ public QueryFragment generatePredicate(String columnName) {
joinWithSpaces(String.format("%s[%d]", columnName, i), comparisonOperator, columnValue));
}

return new StringQueryFragment(joinWithSpaces(columnName, comparisonOperator, columnValue),
return new StringQueryFragment(
joinWithSpaces(generateMultiValueColumn(columnName, useMultistageEngine), comparisonOperator, columnValue),
generateH2QueryConditionPredicate(h2ComparisonClauses));
}
}
Expand All @@ -1036,7 +1071,7 @@ public QueryFragment generatePredicate(String columnName) {
private class MultiValueInPredicateGenerator implements PredicateGenerator {

@Override
public QueryFragment generatePredicate(String columnName) {
public QueryFragment generatePredicate(String columnName, boolean useMultistageEngine) {
List<String> columnValues = _columnToValueList.get(columnName);

int numValues = Math.min(RANDOM.nextInt(MAX_NUM_IN_CLAUSE_VALUES) + 1, columnValues.size());
Expand All @@ -1052,7 +1087,8 @@ public QueryFragment generatePredicate(String columnName) {
h2InClauses.add(String.format("%s[%d] IN (%s)", columnName, i, inValues));
}

return new StringQueryFragment(String.format("%s IN (%s)", columnName, inValues),
return new StringQueryFragment(
String.format("%s IN (%s)", generateMultiValueColumn(columnName, useMultistageEngine), inValues),
generateH2QueryConditionPredicate(h2InClauses));
}
}
Expand All @@ -1063,7 +1099,7 @@ public QueryFragment generatePredicate(String columnName) {
private class MultiValueBetweenPredicateGenerator implements PredicateGenerator {

@Override
public QueryFragment generatePredicate(String columnName) {
public QueryFragment generatePredicate(String columnName, boolean useMultistageEngine) {
List<String> columnValues = _columnToValueList.get(columnName);
String leftValue = pickRandom(columnValues);
String rightValue = pickRandom(columnValues);
Expand All @@ -1074,13 +1110,24 @@ public QueryFragment generatePredicate(String columnName) {
h2ComparisonClauses.add(String.format("%s[%d] BETWEEN %s AND %s", columnName, i, leftValue, rightValue));
}

return new StringQueryFragment(String.format("%s BETWEEN %s AND %s", columnName, leftValue, rightValue),
return new StringQueryFragment(
String.format("%s BETWEEN %s AND %s", generateMultiValueColumn(columnName, useMultistageEngine), leftValue,
rightValue),
generateH2QueryConditionPredicate(h2ComparisonClauses));
}
}

private String generateMultiValueColumn(String columnName, boolean useMultistageEngine) {
if (useMultistageEngine) {
return String.format("ARRAY_TO_MV(%s)", columnName);
}
return columnName;
}

private static String generateH2QueryConditionPredicate(List<String> conditionList) {
return generateH2QueryConditionPredicate(conditionList, " OR ");
}

private static String generateH2QueryConditionPredicate(List<String> conditionList, String separator) {
return String.format("( %s )", StringUtils.join(conditionList, separator));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,25 +101,18 @@ protected void cleanupTestTableDataManager(String tableNameWithType) {
}, 600_000L, "Failed to delete table data managers");
}

/**
* Test features supported in V2 Multi-stage Engine.
* - Some V1 features will not be supported.
* - Some V1 features will be added as V2 engine feature development progresses.
* @throws Exception
*/
public void testHardcodedQueriesMultiStage()
throws Exception {
testHardcodedQueriesCommon();
}

/**
* Test hard-coded queries.
* @throws Exception
*/
public void testHardcodedQueries()
throws Exception {
testHardcodedQueriesCommon();
testHardCodedQueriesV1();
if (useMultiStageQueryEngine()) {
testHardcodedQueriesV2();
} else {
testHardCodedQueriesV1();
}
}

/**
Expand Down Expand Up @@ -282,6 +275,29 @@ private void testHardcodedQueriesCommon()
testQuery(query, h2Query);
}

private void testHardcodedQueriesV2()
throws Exception {
String query;
String h2Query;

query =
"SELECT DistanceGroup FROM mytable WHERE \"Month\" BETWEEN 1 AND 1 AND arrayToMV(DivAirportSeqIDs) IN "
+ "(1078102, 1142303, 1530402, 1172102, 1291503) OR SecurityDelay IN (1, 0, 14, -9999) LIMIT 10";
h2Query =
"SELECT DistanceGroup FROM mytable WHERE `Month` BETWEEN 1 AND 1 AND (DivAirportSeqIDs[1] IN (1078102, "
+ "1142303, 1530402, 1172102, 1291503) OR DivAirportSeqIDs[2] IN (1078102, 1142303, 1530402, 1172102, "
+ "1291503) OR DivAirportSeqIDs[3] IN (1078102, 1142303, 1530402, 1172102, 1291503) OR "
+ "DivAirportSeqIDs[4] IN (1078102, 1142303, 1530402, 1172102, 1291503) OR DivAirportSeqIDs[5] IN "
+ "(1078102, 1142303, 1530402, 1172102, 1291503)) OR SecurityDelay IN (1, 0, 14, -9999) LIMIT 10000";
testQuery(query, h2Query);

query = "SELECT MIN(ArrDelayMinutes), AVG(CAST(DestCityMarketID AS DOUBLE)) FROM mytable WHERE DivArrDelay < 196";
h2Query =
"SELECT MIN(CAST(`ArrDelayMinutes` AS DOUBLE)), AVG(CAST(`DestCityMarketID` AS DOUBLE)) FROM mytable WHERE "
+ "`DivArrDelay` < 196";
testQuery(query, h2Query);
}

private void testHardCodedQueriesV1()
throws Exception {
String query;
Expand All @@ -295,17 +311,6 @@ private void testHardCodedQueriesV1()
"SELECT CAST(CAST(ArrTime AS varchar) AS LONG) FROM mytable WHERE DaysSinceEpoch <> 16312 AND Carrier = 'DL' "
+ "ORDER BY ArrTime DESC";
testQuery(query);
// TODO: move to common when multistage support MV columns
query =
"SELECT DistanceGroup FROM mytable WHERE \"Month\" BETWEEN 1 AND 1 AND DivAirportSeqIDs IN (1078102, 1142303,"
+ " 1530402, 1172102, 1291503) OR SecurityDelay IN (1, 0, 14, -9999) LIMIT 10";
h2Query =
"SELECT DistanceGroup FROM mytable WHERE `Month` BETWEEN 1 AND 1 AND (DivAirportSeqIDs[1] IN (1078102, "
+ "1142303, 1530402, 1172102, 1291503) OR DivAirportSeqIDs[2] IN (1078102, 1142303, 1530402, 1172102, "
+ "1291503) OR DivAirportSeqIDs[3] IN (1078102, 1142303, 1530402, 1172102, 1291503) OR "
+ "DivAirportSeqIDs[4] IN (1078102, 1142303, 1530402, 1172102, 1291503) OR DivAirportSeqIDs[5] IN "
+ "(1078102, 1142303, 1530402, 1172102, 1291503)) OR SecurityDelay IN (1, 0, 14, -9999) LIMIT 10000";
testQuery(query, h2Query);

// Non-Standard SQL syntax:
// IN_ID_SET
Expand Down Expand Up @@ -472,8 +477,13 @@ protected void testGeneratedQueries(boolean withMultiValues, boolean useMultista
for (int i = 0; i < numQueriesToGenerate; i++) {
QueryGenerator.Query query = queryGenerator.generateQuery();
if (useMultistageEngine) {
// multistage engine follows standard SQL thus should use H2 query string for testing.
testQuery(query.generateH2Query().replace("`", "\""), query.generateH2Query());
if (withMultiValues) {
// For multistage query with MV columns, we need to use Pinot query string for testing.
testQuery(query.generatePinotQuery().replace("`", "\""), query.generateH2Query());
} else {
// multistage engine follows standard SQL thus should use H2 query string for testing.
testQuery(query.generateH2Query().replace("`", "\""), query.generateH2Query());
}
} else {
testQuery(query.generatePinotQuery(), query.generateH2Query());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.pinot.spi.config.table.TableConfig;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.util.TestUtils;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
Expand Down Expand Up @@ -95,17 +96,17 @@ protected boolean useMultiStageQueryEngine() {

@Test
@Override
public void testHardcodedQueriesMultiStage()
public void testHardcodedQueries()
throws Exception {
super.testHardcodedQueriesMultiStage();
super.testHardcodedQueries();
}

@Test
@Override
public void testGeneratedQueries()
throws Exception {
// test multistage engine, currently we don't support MV columns.
super.testGeneratedQueries(false, true);
super.testGeneratedQueries(true, true);
}

@Test
Expand Down Expand Up @@ -485,6 +486,25 @@ public void testLiteralOnlyFunc()
assertEquals(results.get(10).asText(), "hello!");
}

@Test
public void testMultiValueColumnGroupBy()
throws Exception {
String pinotQuery = "SELECT count(*), arrayToMV(RandomAirports) FROM mytable "
+ "GROUP BY arrayToMV(RandomAirports)";
JsonNode jsonNode = postQuery(pinotQuery);
Assert.assertEquals(jsonNode.get("resultTable").get("rows").size(), 154);
}

@Test
public void testMultiValueColumnGroupByOrderBy()
throws Exception {
String pinotQuery = "SELECT count(*), arrayToMV(RandomAirports) FROM mytable "
+ "GROUP BY arrayToMV(RandomAirports) "
+ "ORDER BY arrayToMV(RandomAirports) DESC";
JsonNode jsonNode = postQuery(pinotQuery);
Assert.assertEquals(jsonNode.get("resultTable").get("rows").size(), 154);
}

@AfterClass
public void tearDown()
throws Exception {
Expand Down
Loading

0 comments on commit a0ff2e8

Please sign in to comment.