Skip to content

Commit

Permalink
feat(static): fail on ROWTIME in projection (#3430)
Browse files Browse the repository at this point in the history
* feat(static): fail on ROWTIME in projection

At the moment, static queries do not support returning ROWTIME, as this information is not available in the response from Kafka Streams interactive queries (KS IQ).

In the future, we _may_ choose to support this by always including ROWTIME in the value of the changelog topic, but this is out of scope for this initial MVP.
  • Loading branch information
big-andy-coates authored Sep 27, 2019
1 parent 5f28ff5 commit 2f27b68
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
import io.confluent.ksql.serde.SerdeOption;
import io.confluent.ksql.util.SchemaUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -56,6 +58,7 @@ public class Analysis {
private Optional<JoinInfo> joinInfo = Optional.empty();
private Optional<Expression> whereExpression = Optional.empty();
private final List<SelectExpression> selectExpressions = new ArrayList<>();
private final Set<ColumnRef> selectColumnRefs = new HashSet<>();
private final List<Expression> groupByExpressions = new ArrayList<>();
private Optional<WindowExpression> windowExpression = Optional.empty();
private Optional<ColumnName> partitionBy = Optional.empty();
Expand All @@ -76,6 +79,10 @@ void addSelectItem(final Expression expression, final ColumnName alias) {
selectExpressions.add(SelectExpression.of(alias, expression));
}

/**
 * Records column references used by the projection.
 *
 * <p>The refs are merged into a set, so adding the same ref more than once has no
 * additional effect.
 *
 * @param columnRefs the column references to add.
 */
void addSelectColumnRefs(final Collection<ColumnRef> columnRefs) {
  columnRefs.forEach(selectColumnRefs::add);
}

/**
 * @return the optional {@code into} value currently held by this analysis.
 */
public Optional<Into> getInto() {
return into;
}
Expand All @@ -96,6 +103,10 @@ public List<SelectExpression> getSelectExpressions() {
return Collections.unmodifiableList(selectExpressions);
}

/**
 * @return an unmodifiable view of the column references added via
 *     {@code addSelectColumnRefs}. Callers cannot mutate the returned set.
 */
Set<ColumnRef> getSelectColumnRefs() {
// A view, not a copy: wraps the backing set rather than snapshotting it.
return Collections.unmodifiableSet(selectColumnRefs);
}

/**
 * @return an immutable copy of the group-by expressions held by this analysis.
 */
public List<Expression> getGroupByExpressions() {
// copyOf takes a snapshot, so later changes to the backing list are not reflected.
return ImmutableList.copyOf(groupByExpressions);
}
Expand Down Expand Up @@ -156,7 +167,7 @@ public List<AliasedDataSource> getFromDataSources() {
return ImmutableList.copyOf(fromDataSources);
}

public SourceSchemas getFromSourceSchemas() {
SourceSchemas getFromSourceSchemas() {
final Map<SourceName, LogicalSchema> schemaBySource = fromDataSources.stream()
.collect(Collectors.toMap(
AliasedDataSource::getAlias,
Expand Down
32 changes: 29 additions & 3 deletions ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.TraversalExpressionVisitor;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.windows.KsqlWindowExpression;
import io.confluent.ksql.metastore.MetaStore;
Expand Down Expand Up @@ -62,6 +63,7 @@
import io.confluent.ksql.serde.ValueFormat;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
Expand Down Expand Up @@ -511,7 +513,7 @@ protected AstNode visitSelect(final Select node, final Void context) {
visitSelectStar((AllColumns) selectItem);
} else if (selectItem instanceof SingleColumn) {
final SingleColumn column = (SingleColumn) selectItem;
analysis.addSelectItem(column.getExpression(), column.getAlias());
addSelectItem(column.getExpression(), column.getAlias());
} else {
throw new IllegalArgumentException(
"Unsupported SelectItem type: " + selectItem.getClass().getName());
Expand Down Expand Up @@ -562,14 +564,19 @@ private void visitSelectStar(final AllColumns allColumns) {
? source.getAlias().name() + "_"
: "";

for (final Column column : source.getDataSource().getSchema().columns()) {
final LogicalSchema schema = source.getDataSource().getSchema();
for (final Column column : schema.columns()) {

if (staticQuery && schema.isMetaColumn(column.name())) {
continue;
}

final ColumnReferenceExp selectItem = new ColumnReferenceExp(location,
ColumnRef.of(source.getAlias(), column.name()));

final String alias = aliasPrefix + column.name().name();

analysis.addSelectItem(selectItem, ColumnName.of(alias));
addSelectItem(selectItem, ColumnName.of(alias));
}
}
}
Expand Down Expand Up @@ -598,6 +605,25 @@ public void validate() {
+ System.lineSeparator() + KAFKA_VALUE_FORMAT_LIMITATION_DETAILS);
}
}

/**
 * Registers a projection item with the analysis, along with the column references
 * encountered while visiting its expression.
 *
 * @param exp the select expression to record.
 * @param columnName the name the expression is projected as.
 */
private void addSelectItem(final Expression exp, final ColumnName columnName) {
  final Set<ColumnRef> referenced = new HashSet<>();

  // Walk the expression up front; the analysis is only updated afterwards.
  new TraversalExpressionVisitor<Void>() {
    @Override
    public Void visitColumnReference(final ColumnReferenceExp ref, final Void ctx) {
      referenced.add(ref.getReference());
      return null;
    }
  }.process(exp, null);

  analysis.addSelectItem(exp, columnName);
  analysis.addSelectColumnRefs(referenced);
}
}

@FunctionalInterface
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

import com.google.common.collect.ImmutableList;
import io.confluent.ksql.parser.tree.ResultMaterialization;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;
Expand Down Expand Up @@ -89,6 +91,12 @@ public class StaticQueryValidator implements QueryValidator {
Rule.of(
analysis -> !analysis.getLimitClause().isPresent(),
"Static queries don't support LIMIT clauses."
),
Rule.of(
analysis -> analysis.getSelectColumnRefs().stream()
.map(ColumnRef::name)
.noneMatch(n -> n.equals(SchemaUtil.ROWTIME_NAME)),
"Static queries don't support ROWTIME in select columns."
)
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
package io.confluent.ksql.analyzer;

import static io.confluent.ksql.testutils.AnalysisTestUtil.analyzeQuery;
import static io.confluent.ksql.util.SchemaUtil.ROWTIME_NAME;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
Expand All @@ -35,6 +38,7 @@
import io.confluent.ksql.analyzer.Analyzer.SerdeOptionsSupplier;
import io.confluent.ksql.execution.ddl.commands.KsqlTopic;
import io.confluent.ksql.execution.expression.tree.BooleanLiteral;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.Literal;
import io.confluent.ksql.execution.expression.tree.StringLiteral;
import io.confluent.ksql.execution.plan.SelectExpression;
Expand All @@ -53,6 +57,7 @@
import io.confluent.ksql.parser.tree.Sink;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.planner.plan.JoinNode.JoinType;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.serde.Format;
Expand Down Expand Up @@ -90,6 +95,11 @@
public class AnalyzerFunctionalTest {

private static final Set<SerdeOption> DEFAULT_SERDE_OPTIONS = SerdeOption.none();
private static final SourceName TEST1 = SourceName.of("TEST1");
private static final ColumnName COL0 = ColumnName.of("COL0");
private static final ColumnName COL1 = ColumnName.of("COL1");
private static final ColumnName COL2 = ColumnName.of("COL2");
private static final ColumnName COL3 = ColumnName.of("COL3");

private MutableMetaStore jsonMetaStore;
private MutableMetaStore avroMetaStore;
Expand Down Expand Up @@ -136,17 +146,17 @@ public void testSimpleQueryAnalysis() {
final Analysis analysis = analyzeQuery(simpleQuery, jsonMetaStore);
assertEquals("FROM was not analyzed correctly.",
analysis.getFromDataSources().get(0).getDataSource().getName(),
SourceName.of("TEST1"));
TEST1);
assertThat(analysis.getWhereExpression().get().toString(), is("(TEST1.COL0 > 100)"));

final List<SelectExpression> selects = analysis.getSelectExpressions();
assertThat(selects.get(0).getExpression().toString(), is("TEST1.COL0"));
assertThat(selects.get(1).getExpression().toString(), is("TEST1.COL2"));
assertThat(selects.get(2).getExpression().toString(), is("TEST1.COL3"));

assertThat(selects.get(0).getName(), is(ColumnName.of("COL0")));
assertThat(selects.get(1).getName(), is(ColumnName.of("COL2")));
assertThat(selects.get(2).getName(), is(ColumnName.of("COL3")));
assertThat(selects.get(0).getName(), is(COL0));
assertThat(selects.get(1).getName(), is(COL2));
assertThat(selects.get(2).getName(), is(COL3));
}

@Test
Expand Down Expand Up @@ -202,7 +212,7 @@ public void testBooleanExpressionAnalysis() {
final Analysis analysis = analyzeQuery(queryStr, jsonMetaStore);

assertEquals("FROM was not analyzed correctly.",
analysis.getFromDataSources().get(0).getDataSource().getName(), SourceName.of("TEST1"));
analysis.getFromDataSources().get(0).getDataSource().getName(), TEST1);

final List<SelectExpression> selects = analysis.getSelectExpressions();
assertThat(selects.get(0).getExpression().toString(), is("(TEST1.COL0 = 10)"));
Expand All @@ -215,7 +225,7 @@ public void testFilterAnalysis() {
final String queryStr = "SELECT col0 = 10, col2, col3 > col1 FROM test1 WHERE col0 > 20 EMIT CHANGES;";
final Analysis analysis = analyzeQuery(queryStr, jsonMetaStore);

assertThat(analysis.getFromDataSources().get(0).getDataSource().getName(), is(SourceName.of("TEST1")));
assertThat(analysis.getFromDataSources().get(0).getDataSource().getName(), is(TEST1));

final List<SelectExpression> selects = analysis.getSelectExpressions();
assertThat(selects.get(0).getExpression().toString(), is("(TEST1.COL0 = 10)"));
Expand Down Expand Up @@ -450,6 +460,50 @@ public void shouldThrowOnJoinIfKafkaFormat() {
analyzer.analyze(query, Optional.of(sink));
}

@Test
public void shouldCaptureProjectionColumnRefs() {
  // Given: a projection mixing a bare column, an arithmetic expression and a function call
  query = parseSingle("Select COL0, COL0 + COL1, SUBSTRING(COL2, 1) from TEST1;");

  // When:
  final Set<ColumnRef> refs =
      analyzer.analyze(query, Optional.empty()).getSelectColumnRefs();

  // Then: each referenced column is captured exactly once
  assertThat(refs, containsInAnyOrder(
      ColumnRef.of(TEST1, COL0),
      ColumnRef.of(TEST1, COL1),
      ColumnRef.of(TEST1, COL2)
  ));
}

@Test
public void shouldIncludeMetaColumnsForSelectStarOnContinuousQueries() {
  // Given: a continuous (EMIT CHANGES) query selecting every column
  query = parseSingle("Select * from TEST1 EMIT CHANGES;");

  // When:
  final List<SelectExpression> projection =
      analyzer.analyze(query, Optional.empty()).getSelectExpressions();

  // Then: ROWTIME is part of the expanded projection
  final ColumnReferenceExp rowTimeRef =
      new ColumnReferenceExp(ColumnRef.of(TEST1, ROWTIME_NAME));
  assertThat(projection, hasItem(SelectExpression.of(ROWTIME_NAME, rowTimeRef)));
}

@Test
public void shouldNotIncludeMetaColumnsForSelectStartOnStaticQueries() {
  // Given: a static query (no EMIT CHANGES) selecting every column
  query = parseSingle("Select * from TEST1;");

  // When:
  final List<SelectExpression> projection =
      analyzer.analyze(query, Optional.empty()).getSelectExpressions();

  // Then: ROWTIME is absent from the expanded projection
  final ColumnReferenceExp rowTimeRef =
      new ColumnReferenceExp(ColumnRef.of(TEST1, ROWTIME_NAME));
  assertThat(projection, not(hasItem(SelectExpression.of(ROWTIME_NAME, rowTimeRef))));
}

@SuppressWarnings("unchecked")
private <T extends Statement> T parseSingle(final String simpleQuery) {
return (T) Iterables.getOnlyElement(parse(simpleQuery, jsonMetaStore));
Expand Down Expand Up @@ -478,7 +532,7 @@ private void buildProps() {

private void registerKafkaSource() {
final LogicalSchema schema = LogicalSchema.builder()
.valueColumn(ColumnName.of("COL0"), SqlTypes.BIGINT)
.valueColumn(COL0, SqlTypes.BIGINT)
.build();

final KsqlTopic topic = new KsqlTopic(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.confluent.ksql.analyzer.Analysis.Into;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.parser.tree.ResultMaterialization;
import io.confluent.ksql.parser.tree.WindowExpression;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.Optional;
import java.util.OptionalInt;
import org.junit.Before;
Expand Down Expand Up @@ -109,7 +112,7 @@ public void shouldThrowOnStaticQueryThatIsWindowed() {
}

@Test
public void shouldThrowOnStaticQueryThatHasGroupBy() {
public void shouldThrowOnGroupBy() {
// Given:
when(analysis.getGroupByExpressions()).thenReturn(ImmutableList.of(AN_EXPRESSION));

Expand All @@ -122,7 +125,7 @@ public void shouldThrowOnStaticQueryThatHasGroupBy() {
}

@Test
public void shouldThrowOnStaticQueryThatHasPartitionBy() {
public void shouldThrowOnPartitionBy() {
// Given:
when(analysis.getPartitionBy()).thenReturn(Optional.of(ColumnName.of("Something")));

Expand All @@ -135,7 +138,7 @@ public void shouldThrowOnStaticQueryThatHasPartitionBy() {
}

@Test
public void shouldThrowOnStaticQueryThatHasHavingClause() {
public void shouldThrowOnHavingClause() {
// Given:
when(analysis.getHavingExpression()).thenReturn(Optional.of(AN_EXPRESSION));

Expand All @@ -148,7 +151,7 @@ public void shouldThrowOnStaticQueryThatHasHavingClause() {
}

@Test
public void shouldThrowOnStaticQueryThatHasLimitClause() {
public void shouldThrowOnLimitClause() {
// Given:
when(analysis.getLimitClause()).thenReturn(OptionalInt.of(1));

Expand All @@ -159,4 +162,18 @@ public void shouldThrowOnStaticQueryThatHasLimitClause() {
// When:
validator.validate(analysis);
}

@Test
public void shouldThrowOnRowTimeInProjection() {
  // Given: a projection that references ROWTIME
  final ColumnRef rowTime = ColumnRef.of(SchemaUtil.ROWTIME_NAME);
  when(analysis.getSelectColumnRefs()).thenReturn(ImmutableSet.of(rowTime));

  // Then:
  expectedException.expect(KsqlException.class);
  expectedException.expectMessage("Static queries don't support ROWTIME in select columns.");

  // When:
  validator.validate(analysis);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,32 @@
}
]
},
{
"name": "non-windowed projection WITH ROWTIME",
"statements": [
"CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT FROM INPUT GROUP BY ROWKEY;",
"SELECT ROWTIME + 10, COUNT FROM AGGREGATE WHERE ROWKEY='10';"
],
"expectedError": {
"type": "io.confluent.ksql.rest.entity.KsqlStatementErrorMessage",
"message": "Static queries don't support ROWTIME in select columns.",
"status": 400
}
},
{
"name": "windowed with projection with ROWTIME",
"statements": [
"CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ROWKEY;",
"SELECT COUNT, ROWTIME + 10 FROM AGGREGATE WHERE ROWKEY='10' AND WindowStart=12000;"
],
"expectedError": {
"type": "io.confluent.ksql.rest.entity.KsqlStatementErrorMessage",
"message": "Static queries don't support ROWTIME in select columns.",
"status": 400
}
},
{
"name": "text datetime window bounds",
"enabled": false,
Expand Down

0 comments on commit 2f27b68

Please sign in to comment.