diff --git a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java index 734d9e38d81e..14f844119639 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java @@ -221,12 +221,6 @@ public SchemaKStream buildStream(final KsqlQueryBuilder builder) { .getKsqlTopic() .getValueFormat(); - final Serde genericRowSerde = builder.buildValueSerde( - valueFormat.getFormatInfo(), - PhysicalSchema.from(prepareSchema, SerdeOption.none()), - groupByContext.getQueryContext() - ); - final List internalGroupByColumns = internalSchema.resolveGroupByExpressions( getGroupByExpressions(), aggregateArgExpanded, @@ -235,9 +229,9 @@ public SchemaKStream buildStream(final KsqlQueryBuilder builder) { final SchemaKGroupedStream schemaKGroupedStream = aggregateArgExpanded.groupBy( valueFormat, - genericRowSerde, internalGroupByColumns, - groupByContext + groupByContext, + builder ); // Aggregate computations diff --git a/ksql-engine/src/main/java/io/confluent/ksql/streams/StreamsFactories.java b/ksql-engine/src/main/java/io/confluent/ksql/streams/StreamsFactories.java index 1d31351e5826..6616f1cabd46 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/streams/StreamsFactories.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/streams/StreamsFactories.java @@ -15,6 +15,7 @@ package io.confluent.ksql.streams; +import io.confluent.ksql.execution.streams.GroupedFactory; import io.confluent.ksql.execution.streams.MaterializedFactory; import io.confluent.ksql.util.KsqlConfig; import java.util.Objects; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java index c06944acaae8..8303bd03b01e 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java @@ -23,8 +23,6 @@ import com.google.common.collect.ImmutableList; import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; -import io.confluent.ksql.execution.codegen.CodeGenRunner; -import io.confluent.ksql.execution.codegen.ExpressionMetadata; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.context.QueryLoggerUtil; import io.confluent.ksql.execution.ddl.commands.KsqlTopic; @@ -37,12 +35,15 @@ import io.confluent.ksql.execution.plan.LogicalSchemaWithMetaAndKeyFields; import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.execution.plan.StreamFilter; +import io.confluent.ksql.execution.plan.StreamGroupBy; +import io.confluent.ksql.execution.plan.StreamGroupByKey; import io.confluent.ksql.execution.plan.StreamMapValues; import io.confluent.ksql.execution.plan.StreamSelectKey; import io.confluent.ksql.execution.plan.StreamSource; import io.confluent.ksql.execution.plan.StreamToTable; import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.execution.streams.StreamFilterBuilder; +import io.confluent.ksql.execution.streams.StreamGroupByBuilder; import io.confluent.ksql.execution.streams.StreamMapValuesBuilder; import io.confluent.ksql.execution.streams.StreamSelectKeyBuilder; import io.confluent.ksql.execution.streams.StreamSourceBuilder; @@ -72,12 +73,11 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.connect.data.Struct; import org.apache.kafka.streams.Topology.AutoOffsetReset; -import org.apache.kafka.streams.kstream.Grouped; import org.apache.kafka.streams.kstream.JoinWindows; -import org.apache.kafka.streams.kstream.KGroupedStream; import org.apache.kafka.streams.kstream.KStream; import org.apache.kafka.streams.kstream.KTable; import org.apache.kafka.streams.kstream.Produced; @@ -92,6 +92,8 @@ public class SchemaKStream { private static final FormatOptions FORMAT_OPTIONS = FormatOptions.of(IdentifierUtil::needsQuotes); + static final String GROUP_BY_COLUMN_SEPARATOR = "|+|"; + public enum Type { SOURCE, PROJECT, FILTER, AGGREGATE, SINK, REKEY, JOIN } final KStream kstream; @@ -777,73 +779,38 @@ private boolean rekeyRequired(final List groupByExpressions) { public SchemaKGroupedStream groupBy( final ValueFormat valueFormat, - final Serde valSerde, final List groupByExpressions, - final QueryContext.Stacker contextStacker + final QueryContext.Stacker contextStacker, + final KsqlQueryBuilder queryBuilder ) { final boolean rekey = rekeyRequired(groupByExpressions); final KeyFormat rekeyedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo()); if (!rekey) { - final Grouped grouped = streamsFactories.getGroupedFactory() - .create( - StreamsUtil.buildOpName(contextStacker.getQueryContext()), - keySerde, - valSerde - ); - - final KGroupedStream kgroupedStream = kstream.groupByKey(grouped); - - final KeySerde structKeySerde = getGroupByKeyKeySerde(); - - final ExecutionStep> step = - ExecutionStepFactory.streamGroupBy( - contextStacker, - sourceStep, - Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()), - groupByExpressions - ); - return new SchemaKGroupedStream( - kgroupedStream, - step, - keyFormat, - structKeySerde, - keyField, - Collections.singletonList(this), - ksqlConfig, - functionRegistry - ); + return groupByKey(rekeyedKeyFormat, valueFormat, contextStacker, queryBuilder); } - final GroupBy groupBy = new GroupBy(groupByExpressions); - final KeySerde groupedKeySerde = keySerde .rebind(StructKeyUtil.ROWKEY_SERIALIZED_SCHEMA); - final Grouped grouped = streamsFactories.getGroupedFactory() - .create( - StreamsUtil.buildOpName(contextStacker.getQueryContext()), - groupedKeySerde, - valSerde - ); - - final KGroupedStream kgroupedStream = kstream - .filter((key, value) -> value != null) - .groupBy(groupBy.mapper, grouped); - + final String aggregateKeyName = groupedKeyNameFor(groupByExpressions); final LegacyField legacyKeyField = LegacyField - .notInSchema(groupBy.aggregateKeyName, SqlTypes.STRING); - - final Optional newKeyCol = getSchema().findValueColumn(groupBy.aggregateKeyName) + .notInSchema(aggregateKeyName, SqlTypes.STRING); + final Optional newKeyCol = getSchema().findValueColumn(aggregateKeyName) .map(Column::name); - final ExecutionStep> source = - ExecutionStepFactory.streamGroupBy( - contextStacker, - sourceStep, - Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()), - groupByExpressions - ); + + final StreamGroupBy source = ExecutionStepFactory.streamGroupBy( + contextStacker, + sourceStep, + Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()), + groupByExpressions + ); return new SchemaKGroupedStream( - kgroupedStream, + StreamGroupByBuilder.build( + kstream, + source, + queryBuilder, + streamsFactories.getGroupedFactory() + ), source, rekeyedKeyFormat, groupedKeySerde, @@ -854,6 +821,37 @@ public SchemaKGroupedStream groupBy( ); } + @SuppressWarnings("unchecked") + private SchemaKGroupedStream groupByKey( + final KeyFormat rekeyedKeyFormat, + final ValueFormat valueFormat, + final QueryContext.Stacker contextStacker, + final KsqlQueryBuilder queryBuilder + ) { + final KeySerde structKeySerde = getGroupByKeyKeySerde(); + final StreamGroupByKey step = + ExecutionStepFactory.streamGroupByKey( + contextStacker, + (ExecutionStep) sourceStep, + Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()) + ); + return new SchemaKGroupedStream( + StreamGroupByBuilder.build( + (KStream) kstream, + step, + queryBuilder, + streamsFactories.getGroupedFactory() + ), + step, + keyFormat, + structKeySerde, + keyField, + Collections.singletonList(this), + ksqlConfig, + functionRegistry + ); + } + @SuppressWarnings("unchecked") private KeySerde getGroupByKeyKeySerde() { if (keySerde.isWindowed()) { @@ -920,18 +918,10 @@ public FunctionRegistry getFunctionRegistry() { return functionRegistry; } - class GroupBy { - - final String aggregateKeyName; - final GroupByMapper mapper; - - GroupBy(final List expressions) { - final List groupBy = CodeGenRunner.compileExpressions( - expressions.stream(), "Group By", getSchema(), ksqlConfig, functionRegistry); - - this.mapper = new GroupByMapper(groupBy); - this.aggregateKeyName = GroupByMapper.keyNameFor(expressions); - } + String groupedKeyNameFor(final List groupByExpressions) { + return groupByExpressions.stream() + .map(Expression::toString) + .collect(Collectors.joining(GROUP_BY_COLUMN_SEPARATOR)); } protected static class KsqlValueJoiner diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKTable.java b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKTable.java index 6ee7d2763760..037c5e7edb12 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKTable.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKTable.java @@ -25,10 +25,11 @@ import io.confluent.ksql.execution.plan.JoinType; import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.execution.plan.TableFilter; +import io.confluent.ksql.execution.plan.TableGroupBy; import io.confluent.ksql.execution.plan.TableMapValues; import io.confluent.ksql.execution.streams.ExecutionStepFactory; -import io.confluent.ksql.execution.streams.StreamsUtil; import io.confluent.ksql.execution.streams.TableFilterBuilder; +import io.confluent.ksql.execution.streams.TableGroupByBuilder; import io.confluent.ksql.execution.streams.TableMapValuesBuilder; import io.confluent.ksql.execution.util.StructKeyUtil; import io.confluent.ksql.function.FunctionRegistry; @@ -51,9 +52,6 @@ import java.util.Set; import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.KeyValue; -import org.apache.kafka.streams.kstream.Grouped; -import org.apache.kafka.streams.kstream.KGroupedTable; import org.apache.kafka.streams.kstream.KStream; import org.apache.kafka.streams.kstream.KTable; import org.apache.kafka.streams.kstream.Produced; @@ -231,48 +229,37 @@ public ExecutionStep> getSourceTableStep() { } @Override + @SuppressWarnings("unchecked") public SchemaKGroupedStream groupBy( final ValueFormat valueFormat, - final Serde valSerde, final List groupByExpressions, - final QueryContext.Stacker contextStacker + final QueryContext.Stacker contextStacker, + final KsqlQueryBuilder queryBuilder ) { - final GroupBy groupBy = new GroupBy(groupByExpressions); final KeyFormat groupedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo()); final KeySerde groupedKeySerde = keySerde .rebind(StructKeyUtil.ROWKEY_SERIALIZED_SCHEMA); - final Grouped grouped = streamsFactories.getGroupedFactory() - .create( - StreamsUtil.buildOpName(contextStacker.getQueryContext()), - groupedKeySerde, - valSerde - ); - - final KGroupedTable kgroupedTable = ktable - .filter((key, value) -> value != null) - .groupBy( - (key, value) -> new KeyValue<>(groupBy.mapper.apply(key, value), value), - grouped - ); + final String aggregateKeyName = groupedKeyNameFor(groupByExpressions); + final LegacyField legacyKeyField = LegacyField.notInSchema(aggregateKeyName, SqlTypes.STRING); + final Optional newKeyField = + getSchema().findValueColumn(aggregateKeyName).map(Column::fullName); - final LegacyField legacyKeyField = LegacyField - .notInSchema(groupBy.aggregateKeyName, SqlTypes.STRING); - - final Optional newKeyField = getSchema().findValueColumn(groupBy.aggregateKeyName) - .map(Column::fullName); - - final ExecutionStep> step = - ExecutionStepFactory.tableGroupBy( - contextStacker, - sourceTableStep, - Formats.of(groupedKeyFormat, valueFormat, SerdeOption.none()), - groupByExpressions - ); + final TableGroupBy step = ExecutionStepFactory.tableGroupBy( + contextStacker, + sourceTableStep, + Formats.of(groupedKeyFormat, valueFormat, SerdeOption.none()), + groupByExpressions + ); return new SchemaKGroupedTable( - kgroupedTable, + TableGroupByBuilder.build( + ktable, + step, + queryBuilder, + streamsFactories.getGroupedFactory() + ), step, groupedKeyFormat, groupedKeySerde, diff --git a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/AggregateNodeTest.java b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/AggregateNodeTest.java index 0d98075fe00e..4bc93315f578 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/AggregateNodeTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/AggregateNodeTest.java @@ -52,6 +52,7 @@ import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.PersistenceSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.serde.FormatInfo; import io.confluent.ksql.serde.KeySerde; import io.confluent.ksql.serde.WindowInfo; import io.confluent.ksql.structured.SchemaKStream; diff --git a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/KsqlBareOutputNodeTest.java b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/KsqlBareOutputNodeTest.java index 8a00da999e71..5ffd05022578 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/KsqlBareOutputNodeTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/KsqlBareOutputNodeTest.java @@ -35,6 +35,7 @@ import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.serde.FormatInfo; import io.confluent.ksql.serde.KeySerde; import io.confluent.ksql.structured.SchemaKStream; import io.confluent.ksql.testutils.AnalysisTestUtil; diff --git a/ksql-engine/src/test/java/io/confluent/ksql/streams/GroupedFactoryTest.java b/ksql-engine/src/test/java/io/confluent/ksql/streams/GroupedFactoryTest.java index 8654470bb195..e72bf4183615 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/streams/GroupedFactoryTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/streams/GroupedFactoryTest.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableMap; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.streams.GroupedFactory; import io.confluent.ksql.util.KsqlConfig; import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.streams.StreamsConfig; diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java index 4b2bb972c568..193597cf6477 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java @@ -32,6 +32,7 @@ import com.google.common.collect.ImmutableMap; import io.confluent.kafka.schemaregistry.client.MockSchemaRegistryClient; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.FunctionCall; @@ -137,6 +138,8 @@ public class SchemaKGroupedTableTest { private KsqlAggregateFunction otherFunc; @Mock private TableAggregationFunction tableFunc; + @Mock + private KsqlQueryBuilder queryBuilder; private KTable kTable; private KsqlTable ksqlTable; @@ -160,6 +163,9 @@ public void init() { Consumed.with(Serdes.String(), rowSerde) ); + when(queryBuilder.getFunctionRegistry()).thenReturn(functionRegistry); + when(queryBuilder.getKsqlConfig()).thenReturn(ksqlConfig); + when(aggregateSchema.findValueColumn("GROUPING_COLUMN")) .thenReturn(Optional.of(Column.of("GROUPING_COLUMN", SqlTypes.STRING))); @@ -171,6 +177,7 @@ private ExecutionStep buildSourceTableStep(final LogicalSchema schema) { when(step.getProperties()).thenReturn( new DefaultExecutionStepProperties(schema, queryContext.getQueryContext()) ); + when(step.getSchema()).thenReturn(schema); return step; } @@ -198,18 +205,8 @@ private SchemaKGroupedTable buildSchemaKGroupedTableFromQuery( .map(c -> new QualifiedNameReference(QualifiedName.of("TEST1", c))) .collect(Collectors.toList()); - final Serde rowSerde = GenericRowSerDe.from( - FormatInfo.of(Format.JSON, Optional.empty()), - PersistenceSchema - .from(initialSchemaKTable.getSchema().withoutAlias().valueConnectSchema(), false), - null, - () -> null, - "test", - processingLogContext - ); - final SchemaKGroupedStream groupedSchemaKTable = initialSchemaKTable.groupBy( - valueFormat, rowSerde, groupByExpressions, queryContext); + valueFormat, groupByExpressions, queryContext, queryBuilder); Assert.assertThat(groupedSchemaKTable, instanceOf(SchemaKGroupedTable.class)); return (SchemaKGroupedTable)groupedSchemaKTable; } diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKStreamTest.java b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKStreamTest.java index 5c4bee374f87..cfa08f98ff6c 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKStreamTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKStreamTest.java @@ -71,6 +71,7 @@ import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.PersistenceSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.serde.Format; import io.confluent.ksql.serde.FormatInfo; @@ -79,7 +80,7 @@ import io.confluent.ksql.serde.KeySerde; import io.confluent.ksql.serde.SerdeOption; import io.confluent.ksql.serde.ValueFormat; -import io.confluent.ksql.streams.GroupedFactory; +import io.confluent.ksql.execution.streams.GroupedFactory; import io.confluent.ksql.streams.JoinedFactory; import io.confluent.ksql.streams.StreamsFactories; import io.confluent.ksql.structured.SchemaKStream.Type; @@ -148,17 +149,15 @@ public class SchemaKStreamTest { private Serde leftSerde; private Serde rightSerde; private LogicalSchema joinSchema; - private Serde rowSerde; - private KeyFormat keyFormat = KeyFormat.nonWindowed(FormatInfo.of(Format.JSON)); - private ValueFormat valueFormat = ValueFormat.of(FormatInfo.of(Format.JSON)); - private ValueFormat rightFormat = ValueFormat.of(FormatInfo.of(Format.DELIMITED)); + private final KeyFormat keyFormat = KeyFormat.nonWindowed(FormatInfo.of(Format.KAFKA)); + private final ValueFormat valueFormat = ValueFormat.of(FormatInfo.of(Format.JSON)); + private final ValueFormat rightFormat = ValueFormat.of(FormatInfo.of(Format.DELIMITED)); private final LogicalSchema simpleSchema = LogicalSchema.builder() .valueColumn("key", SqlTypes.STRING) .valueColumn("val", SqlTypes.BIGINT) .build(); private final QueryContext.Stacker queryContext = new QueryContext.Stacker(new QueryId("query")).push("node"); - private final QueryContext parentContext = queryContext.push("parent").getQueryContext(); private final QueryContext.Stacker childContextStacker = queryContext.push("child"); private final ProcessingLogContext processingLogContext = ProcessingLogContext.create(); @@ -177,8 +176,6 @@ public class SchemaKStreamTest { @Mock private KeySerde reboundKeySerde; @Mock - private KeySerde windowedKeySerde; - @Mock private ExecutionStepProperties tableSourceProperties; @Mock private ExecutionStep tableSourceStep; @@ -604,9 +601,9 @@ public void testGroupByKey() { // When: final SchemaKGroupedStream groupedSchemaKStream = initialSchemaKStream.groupBy( valueFormat, - rowSerde, groupBy, - childContextStacker); + childContextStacker, + queryBuilder); // Then: assertThat(groupedSchemaKStream.getKeyField().name(), is(Optional.of("TEST1.COL0"))); @@ -614,7 +611,7 @@ public void testGroupByKey() { } @Test - public void shouldBuildStepForGroupBy() { + public void shouldBuildStepForGroupByKey() { // Given: givenInitialKStreamOf("SELECT col0, col1 FROM test1 WHERE col0 > 100 EMIT CHANGES;"); final List groupBy = Collections.singletonList( @@ -624,9 +621,38 @@ public void shouldBuildStepForGroupBy() { // When: final SchemaKGroupedStream groupedSchemaKStream = initialSchemaKStream.groupBy( valueFormat, - rowSerde, groupBy, - childContextStacker); + childContextStacker, + queryBuilder); + + // Then: + final KeyFormat expectedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo()); + assertThat( + groupedSchemaKStream.getSourceStep(), + equalTo( + ExecutionStepFactory.streamGroupByKey( + childContextStacker, + initialSchemaKStream.getSourceStep(), + Formats.of(expectedKeyFormat, valueFormat, SerdeOption.none()) + ) + ) + ); + } + + @Test + public void shouldBuildStepForGroupBy() { + // Given: + givenInitialKStreamOf("SELECT col0, col1 FROM test1 WHERE col0 > 100 EMIT CHANGES;"); + final List groupBy = Collections.singletonList( + new QualifiedNameReference(QualifiedName.of("TEST1", "COL1")) + ); + + // When: + final SchemaKGroupedStream groupedSchemaKStream = initialSchemaKStream.groupBy( + valueFormat, + groupBy, + childContextStacker, + queryBuilder); // Then: final KeyFormat expectedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo()); @@ -656,9 +682,9 @@ public void testGroupByMultipleColumns() { // When: final SchemaKGroupedStream groupedSchemaKStream = initialSchemaKStream.groupBy( valueFormat, - rowSerde, groupBy, - childContextStacker); + childContextStacker, + queryBuilder); // Then: assertThat(groupedSchemaKStream.getKeyField().name(), is(Optional.empty())); @@ -675,9 +701,9 @@ public void testGroupByMoreComplexExpression() { // When: final SchemaKGroupedStream groupedSchemaKStream = initialSchemaKStream.groupBy( valueFormat, - rowSerde, ImmutableList.of(groupBy), - childContextStacker); + childContextStacker, + queryBuilder); // Then: assertThat(groupedSchemaKStream.getKeyField().name(), is(Optional.empty())); @@ -695,13 +721,15 @@ public void shouldUseFactoryForGroupedWithoutRekey() { )); final List groupByExpressions = Collections.singletonList(keyExpression); givenInitialSchemaKStreamUsesMocks(); + when(queryBuilder.buildKeySerde(any(), any(), any())).thenReturn(keySerde); + when(queryBuilder.buildValueSerde(any(), any(), any())).thenReturn(leftSerde); // When: initialSchemaKStream.groupBy( valueFormat, - leftSerde, groupByExpressions, - childContextStacker); + childContextStacker, + queryBuilder); // Then: verify(mockGroupedFactory).create( @@ -710,6 +738,17 @@ public void shouldUseFactoryForGroupedWithoutRekey() { same(leftSerde) ); verify(mockKStream).groupByKey(same(grouped)); + final LogicalSchema logicalSchema = ksqlStream.getSchema().withAlias(ksqlStream.getName()); + verify(queryBuilder).buildKeySerde( + FormatInfo.of(Format.KAFKA), + PhysicalSchema.from(logicalSchema, SerdeOption.none()), + childContextStacker.getQueryContext() + ); + verify(queryBuilder).buildValueSerde( + valueFormat.getFormatInfo(), + PhysicalSchema.from(logicalSchema, SerdeOption.none()), + childContextStacker.getQueryContext() + ); } @Test @@ -725,13 +764,15 @@ public void shouldUseFactoryForGrouped() { new QualifiedNameReference(QualifiedName.of(ksqlStream.getName(), "COL1")); final List groupByExpressions = Arrays.asList(col1Expression, col0Expression); givenInitialSchemaKStreamUsesMocks(); + when(queryBuilder.buildKeySerde(any(), any(), any())).thenReturn(reboundKeySerde); + when(queryBuilder.buildValueSerde(any(), any(), any())).thenReturn(leftSerde); // When: initialSchemaKStream.groupBy( valueFormat, - leftSerde, groupByExpressions, - childContextStacker); + childContextStacker, + queryBuilder); // Then: verify(mockGroupedFactory).create( @@ -739,6 +780,17 @@ public void shouldUseFactoryForGrouped() { same(reboundKeySerde), same(leftSerde)); verify(mockKStream).groupBy(any(KeyValueMapper.class), same(grouped)); + final LogicalSchema logicalSchema = ksqlStream.getSchema().withAlias(ksqlStream.getName()); + verify(queryBuilder).buildKeySerde( + FormatInfo.of(Format.KAFKA), + PhysicalSchema.from(logicalSchema, SerdeOption.none()), + childContextStacker.getQueryContext() + ); + verify(queryBuilder).buildValueSerde( + valueFormat.getFormatInfo(), + PhysicalSchema.from(logicalSchema, SerdeOption.none()), + childContextStacker.getQueryContext() + ); } @Test @@ -1249,6 +1301,7 @@ private void verifyCreateJoined(final Serde rightSerde) { private void givenSourcePropertiesWithSchema(final LogicalSchema schema) { reset(sourceProperties); + when(sourceStep.getSchema()).thenReturn(schema); when(sourceProperties.getSchema()).thenReturn(schema); when(sourceProperties.withQueryContext(any())).thenAnswer( i -> new DefaultExecutionStepProperties(schema, (QueryContext) i.getArguments()[0]) @@ -1337,7 +1390,6 @@ private PlanNode givenInitialKStreamOf(final String selectQuery) { selectQuery, metaStore ); - givenSourcePropertiesWithSchema(logicalPlan.getTheSourceNode().getSchema()); initialSchemaKStream = new SchemaKStream( kStream, @@ -1351,14 +1403,6 @@ private PlanNode givenInitialKStreamOf(final String selectQuery) { functionRegistry ); - rowSerde = GenericRowSerDe.from( - FormatInfo.of(Format.JSON, Optional.empty()), - PersistenceSchema.from(initialSchemaKStream.getSchema().valueConnectSchema(), false), - null, - () -> null, - "test", - processingLogContext); - return logicalPlan; } } diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKTableTest.java b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKTableTest.java index e99e0d9d217a..58ba569d158b 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKTableTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKTableTest.java @@ -77,7 +77,7 @@ import io.confluent.ksql.serde.KeySerde; import io.confluent.ksql.serde.SerdeOption; import io.confluent.ksql.serde.ValueFormat; -import io.confluent.ksql.streams.GroupedFactory; +import io.confluent.ksql.execution.streams.GroupedFactory; import io.confluent.ksql.streams.JoinedFactory; import io.confluent.ksql.streams.StreamsFactories; import io.confluent.ksql.structured.SchemaKStream.Type; @@ -141,7 +141,6 @@ public class SchemaKTableTest { = new QueryContext.Stacker(new QueryId("query")).push("node"); private final QueryContext.Stacker childContextStacker = queryContext.push("child"); private final ProcessingLogContext processingLogContext = ProcessingLogContext.create(); - private Serde rowSerde; private static final Expression TEST_2_COL_1 = new QualifiedNameReference(QualifiedName.of("TEST2", "COL1")); private static final Expression TEST_2_COL_2 = @@ -191,6 +190,7 @@ private ExecutionStep buildSourceStep(final LogicalSchema schema) { final ExecutionStep sourceStep = Mockito.mock(ExecutionStep.class); when(sourceStep.getProperties()).thenReturn( new DefaultExecutionStepProperties(schema, queryContext.getQueryContext())); + when(sourceStep.getSchema()).thenReturn(schema); return sourceStep; } @@ -436,15 +436,14 @@ public void testGroupBy() { final String selectQuery = "SELECT col0, col1, col2 FROM test2 EMIT CHANGES;"; final PlanNode logicalPlan = buildLogicalPlan(selectQuery); initialSchemaKTable = buildSchemaKTableFromPlan(logicalPlan); - final Serde rowSerde = mock(Serde.class); final List groupByExpressions = Arrays.asList(TEST_2_COL_2, TEST_2_COL_1); // When: final SchemaKGroupedStream groupedSchemaKTable = initialSchemaKTable.groupBy( valueFormat, - rowSerde, groupByExpressions, - childContextStacker); + childContextStacker, + queryBuilder); // Then: assertThat(groupedSchemaKTable, instanceOf(SchemaKGroupedTable.class)); @@ -459,15 +458,14 @@ public void shouldBuildStepForGroupBy() { final String selectQuery = "SELECT col0, col1, col2 FROM test2 EMIT CHANGES;"; final PlanNode logicalPlan = buildLogicalPlan(selectQuery); initialSchemaKTable = buildSchemaKTableFromPlan(logicalPlan); - final Serde rowSerde = mock(Serde.class); final List groupByExpressions = Arrays.asList(TEST_2_COL_2, TEST_2_COL_1); // When: final SchemaKGroupedStream groupedSchemaKTable = initialSchemaKTable.groupBy( valueFormat, - rowSerde, groupByExpressions, - childContextStacker); + childContextStacker, + queryBuilder); // Then: assertThat( @@ -488,6 +486,7 @@ public void shouldUseOpNameForGrouped() { // Given: final Serde valSerde = getRowSerde(ksqlTable.getKsqlTopic(), ksqlTable.getSchema().valueConnectSchema()); + when(queryBuilder.buildValueSerde(any(), any(), any())).thenReturn(valSerde); expect( groupedFactory.create( eq(StreamsUtil.buildOpName(childContextStacker.getQueryContext())), @@ -503,7 +502,7 @@ public void shouldUseOpNameForGrouped() { final SchemaKTable schemaKTable = buildSchemaKTable(ksqlTable, mockKTable, groupedFactory); // When: - schemaKTable.groupBy(valueFormat, valSerde, groupByExpressions, childContextStacker); + schemaKTable.groupBy(valueFormat, groupByExpressions, childContextStacker, queryBuilder); // Then: verify(mockKTable, groupedFactory); @@ -538,16 +537,9 @@ public void shouldGroupKeysCorrectly() { ); final List groupByExpressions = Arrays.asList(TEST_2_COL_2, TEST_2_COL_1); - final Serde rowSerde = GenericRowSerDe.from( - FormatInfo.of(Format.JSON, Optional.empty()), - PersistenceSchema.from(initialSchemaKTable.getSchema().valueConnectSchema(), false), - null, - () -> null, - "test", - processingLogContext); // Call groupBy and extract the captured mapper - initialSchemaKTable.groupBy(valueFormat, rowSerde, groupByExpressions, childContextStacker); + initialSchemaKTable.groupBy(valueFormat, groupByExpressions, childContextStacker, queryBuilder); verify(mockKTable, mockKGroupedTable); final KeyValueMapper keySelector = capturedKeySelector.getValue(); final GenericRow value = new GenericRow(Arrays.asList("key", 0, 100, "foo", "bar")); @@ -793,7 +785,7 @@ public void shouldSetKeyOnGroupBySingleExpressionThatIsInProjection() { // When: final SchemaKGroupedStream result = initialSchemaKTable - .groupBy(valueFormat, rowSerde, groupByExprs, childContextStacker); + .groupBy(valueFormat, groupByExprs, childContextStacker, queryBuilder); // Then: assertThat(result.getKeyField(), @@ -836,14 +828,6 @@ private List givenInitialKTableOf(final String selectQuery) { functionRegistry ); - rowSerde = GenericRowSerDe.from( - FormatInfo.of(Format.JSON, Optional.empty()), - PersistenceSchema.from(initialSchemaKTable.getSchema().valueConnectSchema(), false), - null, - () -> null, - "test", - processingLogContext); - final ProjectNode projectNode = (ProjectNode) logicalPlan.getSources().get(0); return projectNode.getProjectSelectExpressions(); } diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/ExecutionStep.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/ExecutionStep.java index b1dd8c9e0591..3d07c8598648 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/ExecutionStep.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/ExecutionStep.java @@ -15,6 +15,7 @@ package io.confluent.ksql.execution.plan; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.schema.ksql.LogicalSchema; import java.util.List; public interface ExecutionStep { @@ -23,4 +24,8 @@ public interface ExecutionStep { List> getSources(); S build(KsqlQueryBuilder queryBuilder); + + default LogicalSchema getSchema() { + return getProperties().getSchema(); + } } diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamGroupBy.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamGroupBy.java index 3a2ec496926c..5fe42a01464f 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamGroupBy.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamGroupBy.java @@ -15,22 +15,26 @@ package io.confluent.ksql.execution.plan; import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.expression.tree.Expression; import java.util.Collections; import java.util.List; import java.util.Objects; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; @Immutable -public class StreamGroupBy implements ExecutionStep { +public class StreamGroupBy implements ExecutionStep> { private final ExecutionStepProperties properties; - private final ExecutionStep source; + private final ExecutionStep> source; private final Formats formats; private final List groupByExpressions; public StreamGroupBy( final ExecutionStepProperties properties, - final ExecutionStep source, + final ExecutionStep> source, final Formats formats, final List groupByExpressions) { this.properties = Objects.requireNonNull(properties, "properties"); @@ -53,8 +57,16 @@ public List> getSources() { return Collections.singletonList(source); } + public Formats getFormats() { + return formats; + } + + public ExecutionStep> getSource() { + return source; + } + @Override - public G build(final KsqlQueryBuilder streamsBuilder) { + public KGroupedStream build(final KsqlQueryBuilder streamsBuilder) { throw new UnsupportedOperationException(); } @@ -66,7 +78,7 @@ public boolean equals(final Object o) { if (o == null || getClass() != o.getClass()) { return false; } - final StreamGroupBy that = (StreamGroupBy) o; + final StreamGroupBy that = (StreamGroupBy) o; return Objects.equals(properties, that.properties) && Objects.equals(source, that.source) && Objects.equals(formats, that.formats) @@ -75,7 +87,6 @@ public boolean equals(final Object o) { @Override public int hashCode() { - return Objects.hash(properties, source, formats, groupByExpressions); } } diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamGroupByKey.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamGroupByKey.java new file mode 100644 index 000000000000..5345332cd6af --- /dev/null +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamGroupByKey.java @@ -0,0 +1,84 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License; you may not use this file + * except in compliance with the License. You may obtain a copy of the License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.plan; + +import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; + +@Immutable +public class StreamGroupByKey implements ExecutionStep> { + private final ExecutionStepProperties properties; + private final ExecutionStep> source; + private final Formats formats; + + public StreamGroupByKey( + final ExecutionStepProperties properties, + final ExecutionStep> source, + final Formats formats) { + this.properties = Objects.requireNonNull(properties, "properties"); + this.formats = Objects.requireNonNull(formats, "formats"); + this.source = Objects.requireNonNull(source, "source"); + } + + @Override + public ExecutionStepProperties getProperties() { + return properties; + } + + @Override + public List> getSources() { + return Collections.singletonList(source); + } + + public ExecutionStep> getSource() { + return source; + } + + public Formats getFormats() { + return formats; + } + + @Override + public KGroupedStream build(final KsqlQueryBuilder streamsBuilder) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final StreamGroupByKey that = (StreamGroupByKey) o; + return Objects.equals(properties, that.properties) + && Objects.equals(source, that.source) + && Objects.equals(formats, that.formats); + } + + @Override + public int hashCode() { + + return Objects.hash(properties, source, formats); + } +} diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableGroupBy.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableGroupBy.java index fb031952272b..aba5df392165 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableGroupBy.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableGroupBy.java @@ -15,22 +15,26 @@ package io.confluent.ksql.execution.plan; import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.expression.tree.Expression; import java.util.Collections; import java.util.List; import java.util.Objects; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedTable; +import org.apache.kafka.streams.kstream.KTable; @Immutable -public class TableGroupBy implements ExecutionStep { +public class TableGroupBy implements ExecutionStep> { private final ExecutionStepProperties properties; - private final ExecutionStep source; + private final ExecutionStep> source; private final Formats formats; private final List groupByExpressions; public TableGroupBy( final ExecutionStepProperties properties, - final ExecutionStep source, + final ExecutionStep> source, final Formats formats, final List groupByExpressions ) { @@ -50,8 +54,20 @@ public List> getSources() { return Collections.singletonList(source); } + public Formats getFormats() { + return formats; + } + + public List getGroupByExpressions() { + return groupByExpressions; + } + + public ExecutionStep> getSource() { + return source; + } + @Override - public G build(final KsqlQueryBuilder builder) { + public KGroupedTable build(final KsqlQueryBuilder builder) { throw new UnsupportedOperationException(); } @@ -63,7 +79,7 @@ public boolean equals(final Object o) { if (o == null || getClass() != o.getClass()) { return false; } - final TableGroupBy that = (TableGroupBy) o; + final TableGroupBy that = (TableGroupBy) o; return Objects.equals(properties, that.properties) && Objects.equals(source, that.source) && Objects.equals(formats, that.formats) diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java index aa689dfd3c25..9895350f76c9 100644 --- a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java @@ -28,6 +28,7 @@ import io.confluent.ksql.execution.plan.StreamAggregate; import io.confluent.ksql.execution.plan.StreamFilter; import io.confluent.ksql.execution.plan.StreamGroupBy; +import io.confluent.ksql.execution.plan.StreamGroupByKey; import io.confluent.ksql.execution.plan.StreamMapValues; import io.confluent.ksql.execution.plan.StreamSelectKey; import io.confluent.ksql.execution.plan.StreamSink; @@ -320,12 +321,11 @@ public static TableTableJoin> tableTableJoin( ); } - public static StreamGroupBy, KGroupedStream> - streamGroupBy( - final QueryContext.Stacker stacker, - final ExecutionStep> sourceStep, - final Formats format, - final List groupingExpressions + public static StreamGroupBy streamGroupBy( + final QueryContext.Stacker stacker, + final ExecutionStep> sourceStep, + final Formats format, + final List groupingExpressions ) { final QueryContext queryContext = stacker.getQueryContext(); return new StreamGroupBy<>( @@ -336,6 +336,19 @@ public static TableTableJoin> tableTableJoin( ); } + public static StreamGroupByKey streamGroupByKey( + final QueryContext.Stacker stacker, + final ExecutionStep> sourceStep, + final Formats formats + ) { + final QueryContext queryContext = stacker.getQueryContext(); + return new StreamGroupByKey( + sourceStep.getProperties().withQueryContext(queryContext), + sourceStep, + formats + ); + } + public static TableAggregate, KGroupedTable> tableAggregate( final QueryContext.Stacker stacker, @@ -355,12 +368,11 @@ public static TableTableJoin> tableTableJoin( ); } - public static TableGroupBy, KGroupedTable> - tableGroupBy( - final QueryContext.Stacker stacker, - final ExecutionStep> sourceStep, - final Formats format, - final List groupingExpressions + public static TableGroupBy tableGroupBy( + final QueryContext.Stacker stacker, + final ExecutionStep> sourceStep, + final Formats format, + final List groupingExpressions ) { final QueryContext queryContext = stacker.getQueryContext(); return new TableGroupBy<>( diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/GroupByMapper.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/GroupByMapper.java similarity index 77% rename from ksql-engine/src/main/java/io/confluent/ksql/structured/GroupByMapper.java rename to ksql-streams/src/main/java/io/confluent/ksql/execution/streams/GroupByMapper.java index 8b89ac745589..28420f47c560 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/GroupByMapper.java +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/GroupByMapper.java @@ -13,12 +13,11 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.structured; +package io.confluent.ksql.execution.streams; import com.google.common.collect.ImmutableList; import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.codegen.ExpressionMetadata; -import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.util.StructKeyUtil; import java.util.List; import java.util.Objects; @@ -29,11 +28,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -class GroupByMapper implements KeyValueMapper { +class GroupByMapper implements KeyValueMapper { private static final Logger LOG = LoggerFactory.getLogger(GroupByMapper.class); - private static final String GROUP_BY_COLUMN_SEPARATOR = "|+|"; + private static final String GROUP_BY_VALUE_SEPARATOR = "|+|"; private final List expressions; @@ -45,20 +44,14 @@ class GroupByMapper implements KeyValueMapper { } @Override - public Struct apply(final Object key, final GenericRow row) { + public Struct apply(final K key, final GenericRow row) { final String stringRowKey = IntStream.range(0, expressions.size()) .mapToObj(idx -> processColumn(idx, expressions.get(idx), row)) - .collect(Collectors.joining(GROUP_BY_COLUMN_SEPARATOR)); + .collect(Collectors.joining(GROUP_BY_VALUE_SEPARATOR)); return StructKeyUtil.asStructKey(stringRowKey); } - static String keyNameFor(final List groupByExpressions) { - return groupByExpressions.stream() - .map(Expression::toString) - .collect(Collectors.joining(GROUP_BY_COLUMN_SEPARATOR)); - } - private static String processColumn( final int index, final ExpressionMetadata exp, @@ -71,4 +64,8 @@ private static String processColumn( return "null"; } } + + List getExpressionMetadata() { + return expressions; + } } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/streams/GroupedFactory.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/GroupedFactory.java similarity index 94% rename from ksql-engine/src/main/java/io/confluent/ksql/streams/GroupedFactory.java rename to ksql-streams/src/main/java/io/confluent/ksql/execution/streams/GroupedFactory.java index 0b4a6abd56d9..7c72e0c38ca6 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/streams/GroupedFactory.java +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/GroupedFactory.java @@ -13,9 +13,8 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.streams; +package io.confluent.ksql.execution.streams; -import io.confluent.ksql.execution.streams.StreamsUtil; import io.confluent.ksql.util.KsqlConfig; import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.streams.kstream.Grouped; diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamGroupByBuilder.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamGroupByBuilder.java new file mode 100644 index 000000000000..f6c65018a3d1 --- /dev/null +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamGroupByBuilder.java @@ -0,0 +1,109 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.streams; + +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.codegen.CodeGenRunner; +import io.confluent.ksql.execution.codegen.ExpressionMetadata; +import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.StreamGroupBy; +import io.confluent.ksql.execution.plan.StreamGroupByKey; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; +import io.confluent.ksql.serde.KeySerde; +import java.util.List; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; + +public final class StreamGroupByBuilder { + private StreamGroupByBuilder() { + } + + public static KGroupedStream build( + final KStream kstream, + final StreamGroupByKey step, + final KsqlQueryBuilder queryBuilder, + final GroupedFactory groupedFactory + ) { + final LogicalSchema sourceSchema = step.getSource().getSchema(); + final QueryContext queryContext = step.getProperties().getQueryContext(); + final Formats formats = step.getFormats(); + final Grouped grouped = buildGrouped( + formats, + sourceSchema, + queryContext, + queryBuilder, + groupedFactory + ); + return kstream.groupByKey(grouped); + } + + public static KGroupedStream build( + final KStream kstream, + final StreamGroupBy step, + final KsqlQueryBuilder queryBuilder, + final GroupedFactory groupedFactory + ) { + final LogicalSchema sourceSchema = step.getSource().getSchema(); + final QueryContext queryContext = step.getProperties().getQueryContext(); + final Formats formats = step.getFormats(); + final Grouped grouped = buildGrouped( + formats, + sourceSchema, + queryContext, + queryBuilder, + groupedFactory + ); + final List groupBy = CodeGenRunner.compileExpressions( + step.getGroupByExpressions().stream(), + "Group By", + sourceSchema, + queryBuilder.getKsqlConfig(), + queryBuilder.getFunctionRegistry() + ); + final GroupByMapper mapper = new GroupByMapper<>(groupBy); + return kstream.filter((key, value) -> value != null).groupBy(mapper, grouped); + } + + private static Grouped buildGrouped( + final Formats formats, + final LogicalSchema schema, + final QueryContext queryContext, + final KsqlQueryBuilder queryBuilder, + final GroupedFactory groupedFactory + ) { + final PhysicalSchema physicalSchema = PhysicalSchema.from( + schema, + formats.getOptions() + ); + final KeySerde keySerde = queryBuilder.buildKeySerde( + formats.getKeyFormat().getFormatInfo(), + physicalSchema, + queryContext + ); + final Serde valSerde = queryBuilder.buildValueSerde( + formats.getValueFormat().getFormatInfo(), + physicalSchema, + queryContext + ); + return groupedFactory.create(StreamsUtil.buildOpName(queryContext), keySerde, valSerde); + } +} diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/TableGroupByBuilder.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/TableGroupByBuilder.java new file mode 100644 index 000000000000..6af556396bce --- /dev/null +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/TableGroupByBuilder.java @@ -0,0 +1,100 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.streams; + +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.codegen.CodeGenRunner; +import io.confluent.ksql.execution.codegen.ExpressionMetadata; +import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.TableGroupBy; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; +import io.confluent.ksql.serde.KeySerde; +import java.util.List; +import java.util.Objects; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedTable; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.KeyValueMapper; + +public final class TableGroupByBuilder { + private TableGroupByBuilder() { + } + + public static KGroupedTable build( + final KTable ktable, + final TableGroupBy step, + final KsqlQueryBuilder queryBuilder, + final GroupedFactory groupedFactory + ) { + final LogicalSchema sourceSchema = step.getSource().getSchema(); + final QueryContext queryContext = step.getProperties().getQueryContext(); + final Formats formats = step.getFormats(); + final PhysicalSchema physicalSchema = PhysicalSchema.from( + sourceSchema, + formats.getOptions() + ); + final KeySerde keySerde = queryBuilder.buildKeySerde( + formats.getKeyFormat().getFormatInfo(), + physicalSchema, + queryContext + ); + final Serde valSerde = queryBuilder.buildValueSerde( + formats.getValueFormat().getFormatInfo(), + physicalSchema, + queryContext + ); + final Grouped grouped = groupedFactory.create( + StreamsUtil.buildOpName(queryContext), + keySerde, + valSerde + ); + final List groupBy = CodeGenRunner.compileExpressions( + step.getGroupByExpressions().stream(), + "Group By", + sourceSchema, + queryBuilder.getKsqlConfig(), + queryBuilder.getFunctionRegistry() + ); + final GroupByMapper mapper = new GroupByMapper<>(groupBy); + return ktable + .filter((key, value) -> value != null) + .groupBy(new TableKeyValueMapper<>(mapper), grouped); + } + + public static final class TableKeyValueMapper + implements KeyValueMapper> { + private final GroupByMapper groupByMapper; + + private TableKeyValueMapper(final GroupByMapper groupByMapper) { + this.groupByMapper = Objects.requireNonNull(groupByMapper, "groupByMapper"); + } + + @Override + public KeyValue apply(final K key, final GenericRow value) { + return new KeyValue<>(groupByMapper.apply(key, value), value); + } + + GroupByMapper getGroupByMapper() { + return groupByMapper; + } + } +} diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/GroupByMapperTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/GroupByMapperTest.java similarity index 79% rename from ksql-engine/src/test/java/io/confluent/ksql/structured/GroupByMapperTest.java rename to ksql-streams/src/test/java/io/confluent/ksql/execution/streams/GroupByMapperTest.java index def4018422ae..e3664fb88c4c 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/GroupByMapperTest.java +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/GroupByMapperTest.java @@ -13,7 +13,7 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.structured; +package io.confluent.ksql.execution.streams; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; @@ -47,21 +47,21 @@ public class GroupByMapperTest { @Mock(MockType.NICE) private GenericRow row; - private GroupByMapper mapper; + private GroupByMapper mapper; @Before public void setUp() { - mapper = new GroupByMapper(ImmutableList.of(groupBy0, groupBy1)); + mapper = new GroupByMapper<>(ImmutableList.of(groupBy0, groupBy1)); } @Test(expected = NullPointerException.class) public void shouldThrowOnNullParam() { - new GroupByMapper(null); + new GroupByMapper(null); } @Test(expected = IllegalArgumentException.class) public void shouldThrowOnEmptyParam() { - new GroupByMapper(Collections.emptyList()); + new GroupByMapper(Collections.emptyList()); } @Test @@ -72,7 +72,7 @@ public void shouldGenerateGroupByKey() { EasyMock.replay(groupBy0, groupBy1); // When: - final Struct result = mapper.apply("key", row); + final Struct result = mapper.apply(StructKeyUtil.asStructKey("key"), row); // Then: assertThat(result, is(StructKeyUtil.asStructKey("result0|+|result1"))); @@ -86,7 +86,7 @@ public void shouldSupportNullValues() { EasyMock.replay(groupBy0, groupBy1); // When: - final Struct result = mapper.apply("key", row); + final Struct result = mapper.apply(StructKeyUtil.asStructKey("key"), row); // Then: assertThat(result, is(StructKeyUtil.asStructKey("null|+|result1"))); @@ -100,22 +100,9 @@ public void shouldUseNullIfExpressionThrows() { EasyMock.replay(groupBy0, groupBy1); // When: - final Struct result = mapper.apply("key", row); + final Struct result = mapper.apply(StructKeyUtil.asStructKey("key"), row); // Then: assertThat(result, is(StructKeyUtil.asStructKey("null|+|result1"))); } - - @Test - public void shouldGetKeyName() { - // Given: - final Expression exp0 = new QualifiedNameReference(QualifiedName.of("Fred", "f1")); - final Expression exp1 = new QualifiedNameReference(QualifiedName.of("Bob", "b1")); - - // When: - final String result = GroupByMapper.keyNameFor(ImmutableList.of(exp0, exp1)); - - // Then: - assertThat(result, is("Fred.f1|+|Bob.b1")); - } -} \ No newline at end of file +} diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamGroupByBuilderTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamGroupByBuilderTest.java new file mode 100644 index 000000000000..c008d1739d6b --- /dev/null +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamGroupByBuilderTest.java @@ -0,0 +1,261 @@ +package io.confluent.ksql.execution.streams; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.expression.tree.QualifiedName; +import io.confluent.ksql.execution.expression.tree.QualifiedNameReference; +import io.confluent.ksql.execution.plan.DefaultExecutionStepProperties; +import io.confluent.ksql.execution.plan.ExecutionStep; +import io.confluent.ksql.execution.plan.ExecutionStepProperties; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.StreamGroupBy; +import io.confluent.ksql.execution.plan.StreamGroupByKey; +import io.confluent.ksql.execution.util.StructKeyUtil; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.query.QueryId; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.serde.Format; +import io.confluent.ksql.serde.FormatInfo; +import io.confluent.ksql.serde.KeyFormat; +import io.confluent.ksql.serde.KeySerde; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.ValueFormat; +import io.confluent.ksql.util.KsqlConfig; +import java.util.List; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Predicate; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +public class StreamGroupByBuilderTest { + private static final String ALIAS = "SOURCE"; + private static final LogicalSchema SCHEMA = LogicalSchema.builder() + .valueColumn("PAC", SqlTypes.BIGINT) + .valueColumn("MAN", SqlTypes.STRING) + .build() + .withAlias(ALIAS) + .withMetaAndKeyColsInValue(); + private static final PhysicalSchema PHYSICAL_SCHEMA = + PhysicalSchema.from(SCHEMA, SerdeOption.none()); + private static final List GROUP_BY_EXPRESSIONS = ImmutableList.of( + columnReference("PAC"), + columnReference("MAN") + ); + private static final QueryContext SOURCE_CTX = + new QueryContext.Stacker(new QueryId("qid")).push("foo").push("source").getQueryContext(); + private static final QueryContext STEP_CTX = + new QueryContext.Stacker(new QueryId("qid")).push("foo").push("groupby").getQueryContext(); + private static final ExecutionStepProperties SOURCE_PROPERTIES + = new DefaultExecutionStepProperties(SCHEMA, SOURCE_CTX); + private static final ExecutionStepProperties PROPERTIES = new DefaultExecutionStepProperties( + SCHEMA, + STEP_CTX + ); + private static final Formats FORMATS = Formats.of( + KeyFormat.nonWindowed(FormatInfo.of(Format.KAFKA)), + ValueFormat.of(FormatInfo.of(Format.JSON)), + SerdeOption.none() + ); + + @Mock + private KsqlQueryBuilder queryBuilder; + @Mock + private KsqlConfig ksqlConfig; + @Mock + private FunctionRegistry functionRegistry; + @Mock + private GroupedFactory groupedFactory; + @Mock + private ExecutionStep sourceStep; + @Mock + private KeySerde keySerde; + @Mock + private Serde valueSerde; + @Mock + private Grouped grouped; + @Mock + private KStream sourceStream; + @Mock + private KStream filteredStream; + @Mock + private KGroupedStream groupedStream; + @Captor + private ArgumentCaptor> mapperCaptor; + @Captor + private ArgumentCaptor> predicateCaptor; + + private StreamGroupBy streamGroupBy; + private StreamGroupByKey streamGroupByKey; + + @Rule + public final MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Before + @SuppressWarnings("unchecked") + public void init() { + when(queryBuilder.getKsqlConfig()).thenReturn(ksqlConfig); + when(queryBuilder.getFunctionRegistry()).thenReturn(functionRegistry); + when(queryBuilder.buildKeySerde(any(), any(), any())).thenReturn(keySerde); + when(queryBuilder.buildValueSerde(any(), any(), any())).thenReturn(valueSerde); + when(groupedFactory.create(any(), any(KeySerde.class), any())).thenReturn(grouped); + when(sourceStream.groupByKey(any(Grouped.class))).thenReturn(groupedStream); + when(sourceStream.filter(any())).thenReturn(filteredStream); + when(filteredStream.groupBy(any(KeyValueMapper.class), any(Grouped.class))) + .thenReturn(groupedStream); + when(sourceStep.getProperties()).thenReturn(SOURCE_PROPERTIES); + when(sourceStep.getSchema()).thenReturn(SCHEMA); + streamGroupBy = new StreamGroupBy<>( + PROPERTIES, + sourceStep, + FORMATS, + GROUP_BY_EXPRESSIONS + ); + streamGroupByKey = new StreamGroupByKey(PROPERTIES, sourceStep, FORMATS); + } + + @Test + public void shouldPerformGroupByCorrectly() { + // When: + final KGroupedStream result = + StreamGroupByBuilder.build(sourceStream, streamGroupBy, queryBuilder, groupedFactory); + + // Then: + assertThat(result, is(groupedStream)); + verify(sourceStream).filter(any()); + verify(filteredStream).groupBy(mapperCaptor.capture(), same(grouped)); + verifyNoMoreInteractions(filteredStream, sourceStream); + final GroupByMapper mapper = mapperCaptor.getValue(); + assertThat(mapper.getExpressionMetadata(), hasSize(2)); + assertThat( + mapper.getExpressionMetadata().get(0).getExpression(), + equalTo(GROUP_BY_EXPRESSIONS.get(0)) + ); + assertThat( + mapper.getExpressionMetadata().get(1).getExpression(), + equalTo(GROUP_BY_EXPRESSIONS.get(1)) + ); + } + + @Test + public void shouldFilterNullRowsBeforeGroupBy() { + // When: + StreamGroupByBuilder.build(sourceStream, streamGroupBy, queryBuilder, groupedFactory); + + // Then: + verify(sourceStream).filter(predicateCaptor.capture()); + final Predicate predicate = predicateCaptor.getValue(); + assertThat(predicate.test(StructKeyUtil.asStructKey("foo"), new GenericRow()), is(true)); + assertThat(predicate.test(StructKeyUtil.asStructKey("foo"), null), is(false)); + } + + @Test + public void shouldBuildGroupedCorrectlyForGroupBy() { + // When: + StreamGroupByBuilder.build(sourceStream, streamGroupBy, queryBuilder, groupedFactory); + + // Then: + verify(groupedFactory).create("foo-groupby", keySerde, valueSerde); + } + + @Test + public void shouldBuildKeySerdeCorrectlyForGroupBy() { + // When: + StreamGroupByBuilder.build(sourceStream, streamGroupBy, queryBuilder, groupedFactory); + + // Then: + verify(queryBuilder).buildKeySerde( + FORMATS.getKeyFormat().getFormatInfo(), + PHYSICAL_SCHEMA, + STEP_CTX + ); + } + + @Test + public void shouldBuildValueSerdeCorrectlyForGroupBy() { + // When: + StreamGroupByBuilder.build(sourceStream, streamGroupBy, queryBuilder, groupedFactory); + + // Then: + verify(queryBuilder).buildValueSerde( + FORMATS.getValueFormat().getFormatInfo(), + PHYSICAL_SCHEMA, + STEP_CTX + ); + } + + @Test + public void shouldPerformGroupByKeyCorrectly() { + // When: + final KGroupedStream result = + StreamGroupByBuilder.build(sourceStream, streamGroupByKey, queryBuilder, groupedFactory); + + // Then: + assertThat(result, is(groupedStream)); + verify(sourceStream).groupByKey(grouped); + verifyNoMoreInteractions(sourceStream); + } + + @Test + public void shouldBuildGroupedCorrectlyForGroupByKey() { + // When: + StreamGroupByBuilder.build(sourceStream, streamGroupByKey, queryBuilder, groupedFactory); + + // Then: + verify(groupedFactory).create("foo-groupby", keySerde, valueSerde); + } + + @Test + public void shouldBuildKeySerdeCorrectlyForGroupByKey() { + // When: + StreamGroupByBuilder.build(sourceStream, streamGroupByKey, queryBuilder, groupedFactory); + + // Then: + verify(queryBuilder).buildKeySerde( + FORMATS.getKeyFormat().getFormatInfo(), + PHYSICAL_SCHEMA, + STEP_CTX); + } + + @Test + public void shouldBuildValueSerdeCorrectlyForGroupByKey() { + // When: + StreamGroupByBuilder.build(sourceStream, streamGroupByKey, queryBuilder, groupedFactory); + + // Then: + verify(queryBuilder).buildValueSerde( + FORMATS.getValueFormat().getFormatInfo(), + PHYSICAL_SCHEMA, + STEP_CTX + ); + } + + private static Expression columnReference(final String column) { + return new QualifiedNameReference(QualifiedName.of(ALIAS, column)); + } +} \ No newline at end of file diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableGroupByBuilderTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableGroupByBuilderTest.java new file mode 100644 index 000000000000..d62b898490e5 --- /dev/null +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableGroupByBuilderTest.java @@ -0,0 +1,212 @@ +package io.confluent.ksql.execution.streams; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.expression.tree.QualifiedName; +import io.confluent.ksql.execution.expression.tree.QualifiedNameReference; +import io.confluent.ksql.execution.plan.DefaultExecutionStepProperties; +import io.confluent.ksql.execution.plan.ExecutionStep; +import io.confluent.ksql.execution.plan.ExecutionStepProperties; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.TableGroupBy; +import io.confluent.ksql.execution.streams.TableGroupByBuilder.TableKeyValueMapper; +import io.confluent.ksql.execution.util.StructKeyUtil; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.query.QueryId; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.serde.Format; +import io.confluent.ksql.serde.FormatInfo; +import io.confluent.ksql.serde.KeyFormat; +import io.confluent.ksql.serde.KeySerde; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.ValueFormat; +import io.confluent.ksql.util.KsqlConfig; +import java.util.List; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedTable; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Predicate; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +public class TableGroupByBuilderTest { + private static final String ALIAS = "SOURCE"; + private static final LogicalSchema SCHEMA = LogicalSchema.builder() + .valueColumn("PAC", SqlTypes.BIGINT) + .valueColumn("MAN", SqlTypes.STRING) + .build() + .withAlias(ALIAS) + .withMetaAndKeyColsInValue(); + private static final PhysicalSchema PHYSICAL_SCHEMA = PhysicalSchema.from(SCHEMA, SerdeOption.none()); + + private static final List GROUPBY_EXPRESSIONS = ImmutableList.of( + columnReference("PAC"), + columnReference("MAN") + ); + private static final QueryContext SOURCE_CONTEXT = + new QueryContext.Stacker(new QueryId("qid")).push("foo").push("source").getQueryContext(); + private static final QueryContext STEP_CONTEXT = + new QueryContext.Stacker(new QueryId("qid")).push("foo").push("groupby").getQueryContext(); + private static final ExecutionStepProperties SOURCE_PROPERTIES = + new DefaultExecutionStepProperties(SCHEMA, SOURCE_CONTEXT); + private static final ExecutionStepProperties PROPERTIES = new DefaultExecutionStepProperties( + SCHEMA, + STEP_CONTEXT + ); + private static final Formats FORMATS = Formats.of( + KeyFormat.nonWindowed(FormatInfo.of(Format.KAFKA)), + ValueFormat.of(FormatInfo.of(Format.JSON)), + SerdeOption.none() + ); + + @Mock + private KsqlQueryBuilder queryBuilder; + @Mock + private KsqlConfig ksqlConfig; + @Mock + private FunctionRegistry functionRegistry; + @Mock + private GroupedFactory groupedFactory; + @Mock + private ExecutionStep> sourceStep; + @Mock + private KeySerde keySerde; + @Mock + private Serde valueSerde; + @Mock + private Grouped grouped; + @Mock + private KTable sourceTable; + @Mock + private KTable filteredTable; + @Mock + private KGroupedTable groupedTable; + @Captor + private ArgumentCaptor> mapperCaptor; + @Captor + private ArgumentCaptor> predicateCaptor; + + private TableGroupBy groupBy; + + @Rule + public final MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Before + @SuppressWarnings("unchecked") + public void init() { + when(queryBuilder.getKsqlConfig()).thenReturn(ksqlConfig); + when(queryBuilder.getFunctionRegistry()).thenReturn(functionRegistry); + when(queryBuilder.buildKeySerde(any(), any(), any())).thenReturn(keySerde); + when(queryBuilder.buildValueSerde(any(), any(), any())).thenReturn(valueSerde); + when(groupedFactory.create(any(), any(KeySerde.class), any())).thenReturn(grouped); + when(sourceTable.filter(any())).thenReturn(filteredTable); + when(filteredTable.groupBy(any(KeyValueMapper.class), any(Grouped.class))) + .thenReturn(groupedTable); + when(sourceStep.getProperties()).thenReturn(SOURCE_PROPERTIES); + when(sourceStep.getSchema()).thenReturn(SCHEMA); + groupBy = new TableGroupBy<>( + PROPERTIES, + sourceStep, + FORMATS, + GROUPBY_EXPRESSIONS + ); + } + + @Test + public void shouldPerformGroupByCorrectly() { + // When: + final KGroupedTable result = + TableGroupByBuilder.build(sourceTable, groupBy, queryBuilder, groupedFactory); + + // Then: + assertThat(result, is(groupedTable)); + verify(sourceTable).filter(any()); + verify(filteredTable).groupBy(mapperCaptor.capture(), same(grouped)); + verifyNoMoreInteractions(filteredTable, sourceTable); + final GroupByMapper mapper = mapperCaptor.getValue().getGroupByMapper(); + assertThat(mapper.getExpressionMetadata(), hasSize(2)); + assertThat( + mapper.getExpressionMetadata().get(0).getExpression(), + equalTo(GROUPBY_EXPRESSIONS.get(0)) + ); + assertThat( + mapper.getExpressionMetadata().get(1).getExpression(), + equalTo(GROUPBY_EXPRESSIONS.get(1)) + ); + } + + @Test + public void shouldFilterNullRowsBeforeGroupBy() { + // When: + TableGroupByBuilder.build(sourceTable, groupBy, queryBuilder, groupedFactory); + + // Then: + verify(sourceTable).filter(predicateCaptor.capture()); + final Predicate predicate = predicateCaptor.getValue(); + assertThat(predicate.test(StructKeyUtil.asStructKey("key"), new GenericRow()), is(true)); + assertThat(predicate.test(StructKeyUtil.asStructKey("key"), null), is(false)); + } + + @Test + public void shouldBuildGroupedCorrectlyForGroupBy() { + // When: + TableGroupByBuilder.build(sourceTable, groupBy, queryBuilder, groupedFactory); + + // Then: + verify(groupedFactory).create("foo-groupby", keySerde, valueSerde); + } + + @Test + public void shouldBuildKeySerdeCorrectlyForGroupBy() { + // When: + TableGroupByBuilder.build(sourceTable, groupBy, queryBuilder, groupedFactory); + + // Then: + verify(queryBuilder).buildKeySerde( + FORMATS.getKeyFormat().getFormatInfo(), + PHYSICAL_SCHEMA, + STEP_CONTEXT + ); + } + + @Test + public void shouldBuildValueSerdeCorrectlyForGroupBy() { + // When: + TableGroupByBuilder.build(sourceTable, groupBy, queryBuilder, groupedFactory); + + // Then: + verify(queryBuilder).buildValueSerde( + FORMATS.getValueFormat().getFormatInfo(), + PHYSICAL_SCHEMA, + STEP_CONTEXT + ); + } + + private static Expression columnReference(final String column) { + return new QualifiedNameReference(QualifiedName.of(ALIAS, column)); + } +} \ No newline at end of file