Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQL: Support CASE-style filtered count distinct. #5047

Merged
merged 1 commit into from
Nov 14, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ public static DruidExpression toDruidExpression(
}
} else if (kind == SqlKind.LITERAL) {
// Translate literal.
if (SqlTypeName.NUMERIC_TYPES.contains(sqlTypeName)) {
if (RexLiteral.isNullLiteral(rexNode)) {
return DruidExpression.fromExpression(DruidExpression.nullLiteral());
} else if (SqlTypeName.NUMERIC_TYPES.contains(sqlTypeName)) {
return DruidExpression.fromExpression(DruidExpression.numberLiteral((Number) RexLiteral.value(rexNode)));
} else if (SqlTypeFamily.INTERVAL_DAY_TIME == sqlTypeName.getFamily()) {
// Calcite represents DAY-TIME intervals in milliseconds.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public boolean matches(final RelOptRuleCall call)
}

for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
if (isNonDistinctOneArgAggregateCall(aggregateCall)
if (isOneArgAggregateCall(aggregateCall)
&& isThreeArgCase(project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList())))) {
return true;
}
Expand All @@ -97,21 +97,13 @@ public void onMatch(RelOptRuleCall call)
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
AggregateCall newCall = null;

if (isNonDistinctOneArgAggregateCall(aggregateCall)) {
if (isOneArgAggregateCall(aggregateCall)) {
final RexNode rexNode = project.getChildExps().get(Iterables.getOnlyElement(aggregateCall.getArgList()));

// Styles supported:
//
// A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null)
// A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM
// B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0)
// C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null)
//
// If the null and non-null args are switched, "flip" is set, which negates the filter.

if (isThreeArgCase(rexNode)) {
final RexCall caseCall = (RexCall) rexNode;

// If one arg is null and the other is not, reverse them and set "flip", which negates the filter.
final boolean flip = RexLiteral.isNullLiteral(caseCall.getOperands().get(1))
&& !RexLiteral.isNullLiteral(caseCall.getOperands().get(2));
final RexNode arg1 = caseCall.getOperands().get(flip ? 2 : 1);
Expand All @@ -126,6 +118,7 @@ public void onMatch(RelOptRuleCall call)
ImmutableList.of(caseCall.getOperands().get(0))
);

// Combine the CASE filter with an honest-to-goodness SQL FILTER, if the latter is present.
if (aggregateCall.filterArg >= 0) {
filter = rexBuilder.makeCall(
booleanType,
Expand All @@ -136,47 +129,72 @@ public void onMatch(RelOptRuleCall call)
filter = filterFromCase;
}

if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT
&& arg1 instanceof RexLiteral
&& !RexLiteral.isNullLiteral(arg1)
&& RexLiteral.isNullLiteral(arg2)) {
// Case C
newProjects.add(filter);
newCall = AggregateCall.create(
SqlStdOperatorTable.COUNT,
false,
ImmutableList.of(),
newProjects.size() - 1,
aggregateCall.getType(),
aggregateCall.getName()
);
} else if (aggregateCall.getAggregation().getKind() == SqlKind.SUM
&& Calcites.isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1
&& Calcites.isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
// Case B
newProjects.add(filter);
newCall = AggregateCall.create(
SqlStdOperatorTable.COUNT,
false,
ImmutableList.of(),
newProjects.size() - 1,
typeFactory.createSqlType(SqlTypeName.BIGINT),
aggregateCall.getName()
);
} else if (RexLiteral.isNullLiteral(arg2) /* Case A1 */
|| (aggregateCall.getAggregation().getKind() == SqlKind.SUM
&& Calcites.isIntLiteral(arg2)
&& RexLiteral.intValue(arg2) == 0) /* Case A2 */) {
newProjects.add(arg1);
newProjects.add(filter);
newCall = AggregateCall.create(
aggregateCall.getAggregation(),
false,
ImmutableList.of(newProjects.size() - 2),
newProjects.size() - 1,
aggregateCall.getType(),
aggregateCall.getName()
);
if (aggregateCall.isDistinct()) {
// Just one style supported:
// COUNT(DISTINCT CASE WHEN x = 'foo' THEN y END) => COUNT(DISTINCT y) FILTER(WHERE x = 'foo')

if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT && RexLiteral.isNullLiteral(arg2)) {
newProjects.add(arg1);
newProjects.add(filter);
newCall = AggregateCall.create(
SqlStdOperatorTable.COUNT,
true,
ImmutableList.of(newProjects.size() - 2),
newProjects.size() - 1,
aggregateCall.getType(),
aggregateCall.getName()
);
}
} else {
// Four styles supported:
//
// A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null)
// A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM
// B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0); must be SUM
// C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null)

if (aggregateCall.getAggregation().getKind() == SqlKind.COUNT
&& arg1.isA(SqlKind.LITERAL)
&& !RexLiteral.isNullLiteral(arg1)
&& RexLiteral.isNullLiteral(arg2)) {
// Case C
newProjects.add(filter);
newCall = AggregateCall.create(
SqlStdOperatorTable.COUNT,
false,
ImmutableList.of(),
newProjects.size() - 1,
aggregateCall.getType(),
aggregateCall.getName()
);
} else if (aggregateCall.getAggregation().getKind() == SqlKind.SUM
&& Calcites.isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1
&& Calcites.isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
// Case B
newProjects.add(filter);
newCall = AggregateCall.create(
SqlStdOperatorTable.COUNT,
false,
ImmutableList.of(),
newProjects.size() - 1,
typeFactory.createSqlType(SqlTypeName.BIGINT),
aggregateCall.getName()
);
} else if (RexLiteral.isNullLiteral(arg2) /* Case A1 */
|| (aggregateCall.getAggregation().getKind() == SqlKind.SUM
&& Calcites.isIntLiteral(arg2)
&& RexLiteral.intValue(arg2) == 0) /* Case A2 */) {
newProjects.add(arg1);
newProjects.add(filter);
newCall = AggregateCall.create(
aggregateCall.getAggregation(),
false,
ImmutableList.of(newProjects.size() - 2),
newProjects.size() - 1,
aggregateCall.getType(),
aggregateCall.getName()
);
}
}
}
}
Expand Down Expand Up @@ -211,9 +229,9 @@ public void onMatch(RelOptRuleCall call)
}
}

private static boolean isNonDistinctOneArgAggregateCall(final AggregateCall aggregateCall)
private static boolean isOneArgAggregateCall(final AggregateCall aggregateCall)
{
return aggregateCall.getArgList().size() == 1 && !aggregateCall.isDistinct();
return aggregateCall.getArgList().size() == 1;
}

private static boolean isThreeArgCase(final RexNode rexNode)
Expand Down
4 changes: 0 additions & 4 deletions sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -63,8 +61,6 @@ public static Aggregation translateAggregateCall(
)
{
final DimFilter filter;
final SqlKind kind = call.getAggregation().getKind();
final SqlTypeName outputType = call.getType().getSqlTypeName();

if (call.filterArg >= 0) {
// AGG(xxx) FILTER(WHERE yyy)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ public StringComparator naturalStringComparator(final SimpleExtraction simpleExt
*/
public RelDataType getRelDataType(final RelDataTypeFactory typeFactory)
{
final RelDataTypeFactory.FieldInfoBuilder builder = typeFactory.builder();
final RelDataTypeFactory.Builder builder = typeFactory.builder();
for (final String columnName : columnNames) {
final ValueType columnType = getColumnType(columnName);
final RelDataType type;
Expand Down Expand Up @@ -177,7 +177,10 @@ public RelDataType getRelDataType(final RelDataTypeFactory typeFactory)
break;
case COMPLEX:
// Loses information about exactly what kind of complex column this is.
type = typeFactory.createSqlType(SqlTypeName.OTHER);
type = typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.OTHER),
true
);
break;
default:
throw new ISE("WTF?! valueType[%s] not translatable?", columnType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ public void testDatabaseMetaDataColumns() throws Exception
Pair.of("COLUMN_NAME", "unique_dim1"),
Pair.of("DATA_TYPE", Types.OTHER),
Pair.of("TYPE_NAME", "OTHER"),
Pair.of("IS_NULLABLE", "NO")
Pair.of("IS_NULLABLE", "YES")
)
),
getRows(
Expand Down Expand Up @@ -526,7 +526,7 @@ public void testDatabaseMetaDataColumnsWithSuperuser() throws Exception
Pair.of("COLUMN_NAME", "unique_dim1"),
Pair.of("DATA_TYPE", Types.OTHER),
Pair.of("TYPE_NAME", "OTHER"),
Pair.of("IS_NULLABLE", "NO")
Pair.of("IS_NULLABLE", "YES")
)
),
getRows(
Expand Down
73 changes: 67 additions & 6 deletions sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ public void testInformationSchemaColumnsOnTable() throws Exception
new Object[]{"dim2", "VARCHAR", "YES"},
new Object[]{"m1", "FLOAT", "NO"},
new Object[]{"m2", "DOUBLE", "NO"},
new Object[]{"unique_dim1", "OTHER", "NO"}
new Object[]{"unique_dim1", "OTHER", "YES"}
)
);
}
Expand Down Expand Up @@ -434,12 +434,11 @@ public void testInformationSchemaColumnsOnForbiddenTable() throws Exception
new Object[]{"dim2", "VARCHAR", "YES"},
new Object[]{"m1", "FLOAT", "NO"},
new Object[]{"m2", "DOUBLE", "NO"},
new Object[]{"unique_dim1", "OTHER", "NO"}
new Object[]{"unique_dim1", "OTHER", "YES"}
)
);
}


@Test
public void testInformationSchemaColumnsOnView() throws Exception
{
Expand Down Expand Up @@ -2326,9 +2325,10 @@ public void testFilteredAggregations() throws Exception
+ "SUM(cnt) filter(WHERE dim2 = 'a'), "
+ "SUM(case when dim1 <> '1' then cnt end) filter(WHERE dim2 = 'a'), "
+ "SUM(CASE WHEN dim1 <> '1' THEN cnt ELSE 0 END), "
+ "MAX(CASE WHEN dim1 <> '1' THEN cnt END) "
+ "MAX(CASE WHEN dim1 <> '1' THEN cnt END), "
+ "COUNT(DISTINCT CASE WHEN dim1 <> '1' THEN m1 END) "
+ "FROM druid.foo",
ImmutableList.<Query>of(
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(QSS(Filtration.eternity()))
Expand Down Expand Up @@ -2379,13 +2379,23 @@ public void testFilteredAggregations() throws Exception
new FilteredAggregatorFactory(
new LongMaxAggregatorFactory("a9", "cnt"),
NOT(SELECTOR("dim1", "1", null))
),
new FilteredAggregatorFactory(
new CardinalityAggregatorFactory(
"a10",
null,
DIMS(new DefaultDimensionSpec("m1", "m1", ValueType.FLOAT)),
false,
true
),
NOT(SELECTOR("dim1", "1", null))
)
))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{1L, 5L, 1L, 2L, 5L, 5L, 2L, 1L, 5L, 1L}
new Object[]{1L, 5L, 1L, 2L, 5L, 5L, 2L, 1L, 5L, 1L, 5L}
)
);
}
Expand Down Expand Up @@ -3426,6 +3436,57 @@ public void testCountDistinct() throws Exception
);
}

@Test
public void testCountDistinctOfCaseWhen() throws Exception
{
testQuery(
"SELECT\n"
+ "COUNT(DISTINCT CASE WHEN m1 >= 4 THEN m1 END),\n"
+ "COUNT(DISTINCT CASE WHEN m1 >= 4 THEN dim1 END),\n"
+ "COUNT(DISTINCT CASE WHEN m1 >= 4 THEN unique_dim1 END)\n"
+ "FROM druid.foo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(QSS(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(
AGGS(
new FilteredAggregatorFactory(
new CardinalityAggregatorFactory(
"a0",
null,
ImmutableList.of(new DefaultDimensionSpec("m1", "m1", ValueType.FLOAT)),
false,
true
),
BOUND("m1", "4", null, false, false, null, StringComparators.NUMERIC)
),
new FilteredAggregatorFactory(
new CardinalityAggregatorFactory(
"a1",
null,
ImmutableList.of(new DefaultDimensionSpec("dim1", "dim1", ValueType.STRING)),
false,
true
),
BOUND("m1", "4", null, false, false, null, StringComparators.NUMERIC)
),
new FilteredAggregatorFactory(
new HyperUniquesAggregatorFactory("a2", "unique_dim1", false, true),
BOUND("m1", "4", null, false, false, null, StringComparators.NUMERIC)
)
)
)
.context(TIMESERIES_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{3L, 3L, 3L}
)
);
}

@Test
public void testExactCountDistinct() throws Exception
{
Expand Down