feat: move groupBy into plan builders (#3359)
This patch moves the code for regrouping streams/tables into plan
builders. This also required adding a new execution step for
groupByKey, which we missed the first go-round.
rodesai authored Sep 19, 2019
1 parent e4b3275 commit 730c913
Showing 22 changed files with 1,031 additions and 238 deletions.
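For orientation before the diff: the new call path describes a GROUP BY as an execution-plan node and hands it to a plan builder, instead of constructing the Grouped/serdes inline in SchemaKStream/SchemaKTable. The sketch below is illustrative only; the class and method names (ExecutionStepFactory.streamGroupBy, ExecutionStepFactory.streamGroupByKey, StreamGroupByBuilder.build) are taken from the diff, while the surrounding local variables and inferred return types are simplified stand-ins rather than the exact production code.

    // Illustrative sketch (not part of the commit); names from the diff, locals simplified.
    // Rekey path: build a StreamGroupBy plan node, then let the builder create the
    // Grouped/serdes and perform the Kafka Streams groupBy.
    final StreamGroupBy<K> groupByStep = ExecutionStepFactory.streamGroupBy(
        contextStacker,
        sourceStep,
        Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()),
        groupByExpressions
    );
    final KGroupedStream<Struct, GenericRow> grouped = StreamGroupByBuilder.build(
        kstream,
        groupByStep,
        queryBuilder,
        streamsFactories.getGroupedFactory()
    );

    // No-rekey path: the new StreamGroupByKey step this commit adds, built the same way
    // and handed to the same builder.
    final StreamGroupByKey groupByKeyStep = ExecutionStepFactory.streamGroupByKey(
        contextStacker,
        sourceStep,
        Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none())
    );
    final KGroupedStream<Struct, GenericRow> groupedByKey = StreamGroupByBuilder.build(
        kstream,
        groupByKeyStep,
        queryBuilder,
        streamsFactories.getGroupedFactory()
    );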
@@ -221,12 +221,6 @@ public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {
.getKsqlTopic()
.getValueFormat();

final Serde<GenericRow> genericRowSerde = builder.buildValueSerde(
valueFormat.getFormatInfo(),
PhysicalSchema.from(prepareSchema, SerdeOption.none()),
groupByContext.getQueryContext()
);

final List<Expression> internalGroupByColumns = internalSchema.resolveGroupByExpressions(
getGroupByExpressions(),
aggregateArgExpanded,
@@ -235,9 +229,9 @@ public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {

final SchemaKGroupedStream schemaKGroupedStream = aggregateArgExpanded.groupBy(
valueFormat,
genericRowSerde,
internalGroupByColumns,
groupByContext
groupByContext,
builder
);

// Aggregate computations
@@ -15,6 +15,7 @@

package io.confluent.ksql.streams;

import io.confluent.ksql.execution.streams.GroupedFactory;
import io.confluent.ksql.execution.streams.MaterializedFactory;
import io.confluent.ksql.util.KsqlConfig;
import java.util.Objects;
@@ -23,8 +23,6 @@
import com.google.common.collect.ImmutableList;
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
import io.confluent.ksql.execution.codegen.CodeGenRunner;
import io.confluent.ksql.execution.codegen.ExpressionMetadata;
import io.confluent.ksql.execution.context.QueryContext;
import io.confluent.ksql.execution.context.QueryLoggerUtil;
import io.confluent.ksql.execution.ddl.commands.KsqlTopic;
@@ -37,12 +35,15 @@
import io.confluent.ksql.execution.plan.LogicalSchemaWithMetaAndKeyFields;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.plan.StreamFilter;
import io.confluent.ksql.execution.plan.StreamGroupBy;
import io.confluent.ksql.execution.plan.StreamGroupByKey;
import io.confluent.ksql.execution.plan.StreamMapValues;
import io.confluent.ksql.execution.plan.StreamSelectKey;
import io.confluent.ksql.execution.plan.StreamSource;
import io.confluent.ksql.execution.plan.StreamToTable;
import io.confluent.ksql.execution.streams.ExecutionStepFactory;
import io.confluent.ksql.execution.streams.StreamFilterBuilder;
import io.confluent.ksql.execution.streams.StreamGroupByBuilder;
import io.confluent.ksql.execution.streams.StreamMapValuesBuilder;
import io.confluent.ksql.execution.streams.StreamSelectKeyBuilder;
import io.confluent.ksql.execution.streams.StreamSourceBuilder;
@@ -72,12 +73,11 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.streams.Topology.AutoOffsetReset;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.JoinWindows;
import org.apache.kafka.streams.kstream.KGroupedStream;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Produced;
@@ -92,6 +92,8 @@ public class SchemaKStream<K> {
private static final FormatOptions FORMAT_OPTIONS =
FormatOptions.of(IdentifierUtil::needsQuotes);

static final String GROUP_BY_COLUMN_SEPARATOR = "|+|";

public enum Type { SOURCE, PROJECT, FILTER, AGGREGATE, SINK, REKEY, JOIN }

final KStream<K, GenericRow> kstream;
@@ -777,73 +779,38 @@ private boolean rekeyRequired(final List<Expression> groupByExpressions) {

public SchemaKGroupedStream groupBy(
final ValueFormat valueFormat,
final Serde<GenericRow> valSerde,
final List<Expression> groupByExpressions,
final QueryContext.Stacker contextStacker
final QueryContext.Stacker contextStacker,
final KsqlQueryBuilder queryBuilder
) {
final boolean rekey = rekeyRequired(groupByExpressions);
final KeyFormat rekeyedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo());
if (!rekey) {
final Grouped<K, GenericRow> grouped = streamsFactories.getGroupedFactory()
.create(
StreamsUtil.buildOpName(contextStacker.getQueryContext()),
keySerde,
valSerde
);

final KGroupedStream kgroupedStream = kstream.groupByKey(grouped);

final KeySerde<Struct> structKeySerde = getGroupByKeyKeySerde();

final ExecutionStep<KGroupedStream<Struct, GenericRow>> step =
ExecutionStepFactory.streamGroupBy(
contextStacker,
sourceStep,
Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()),
groupByExpressions
);
return new SchemaKGroupedStream(
kgroupedStream,
step,
keyFormat,
structKeySerde,
keyField,
Collections.singletonList(this),
ksqlConfig,
functionRegistry
);
return groupByKey(rekeyedKeyFormat, valueFormat, contextStacker, queryBuilder);
}

final GroupBy groupBy = new GroupBy(groupByExpressions);

final KeySerde<Struct> groupedKeySerde = keySerde
.rebind(StructKeyUtil.ROWKEY_SERIALIZED_SCHEMA);

final Grouped<Struct, GenericRow> grouped = streamsFactories.getGroupedFactory()
.create(
StreamsUtil.buildOpName(contextStacker.getQueryContext()),
groupedKeySerde,
valSerde
);

final KGroupedStream kgroupedStream = kstream
.filter((key, value) -> value != null)
.groupBy(groupBy.mapper, grouped);

final String aggregateKeyName = groupedKeyNameFor(groupByExpressions);
final LegacyField legacyKeyField = LegacyField
.notInSchema(groupBy.aggregateKeyName, SqlTypes.STRING);

final Optional<String> newKeyCol = getSchema().findValueColumn(groupBy.aggregateKeyName)
.notInSchema(aggregateKeyName, SqlTypes.STRING);
final Optional<String> newKeyCol = getSchema().findValueColumn(aggregateKeyName)
.map(Column::name);
final ExecutionStep<KGroupedStream<Struct, GenericRow>> source =
ExecutionStepFactory.streamGroupBy(
contextStacker,
sourceStep,
Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()),
groupByExpressions
);

final StreamGroupBy<K> source = ExecutionStepFactory.streamGroupBy(
contextStacker,
sourceStep,
Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()),
groupByExpressions
);
return new SchemaKGroupedStream(
kgroupedStream,
StreamGroupByBuilder.build(
kstream,
source,
queryBuilder,
streamsFactories.getGroupedFactory()
),
source,
rekeyedKeyFormat,
groupedKeySerde,
@@ -854,6 +821,37 @@ public SchemaKGroupedStream groupBy(
);
}

@SuppressWarnings("unchecked")
private SchemaKGroupedStream groupByKey(
final KeyFormat rekeyedKeyFormat,
final ValueFormat valueFormat,
final QueryContext.Stacker contextStacker,
final KsqlQueryBuilder queryBuilder
) {
final KeySerde<Struct> structKeySerde = getGroupByKeyKeySerde();
final StreamGroupByKey step =
ExecutionStepFactory.streamGroupByKey(
contextStacker,
(ExecutionStep) sourceStep,
Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none())
);
return new SchemaKGroupedStream(
StreamGroupByBuilder.build(
(KStream) kstream,
step,
queryBuilder,
streamsFactories.getGroupedFactory()
),
step,
keyFormat,
structKeySerde,
keyField,
Collections.singletonList(this),
ksqlConfig,
functionRegistry
);
}

@SuppressWarnings("unchecked")
private KeySerde<Struct> getGroupByKeyKeySerde() {
if (keySerde.isWindowed()) {
@@ -920,18 +918,10 @@ public FunctionRegistry getFunctionRegistry() {
return functionRegistry;
}

class GroupBy {

final String aggregateKeyName;
final GroupByMapper mapper;

GroupBy(final List<Expression> expressions) {
final List<ExpressionMetadata> groupBy = CodeGenRunner.compileExpressions(
expressions.stream(), "Group By", getSchema(), ksqlConfig, functionRegistry);

this.mapper = new GroupByMapper(groupBy);
this.aggregateKeyName = GroupByMapper.keyNameFor(expressions);
}
String groupedKeyNameFor(final List<Expression> groupByExpressions) {
return groupByExpressions.stream()
.map(Expression::toString)
.collect(Collectors.joining(GROUP_BY_COLUMN_SEPARATOR));
}

protected static class KsqlValueJoiner
@@ -25,10 +25,11 @@
import io.confluent.ksql.execution.plan.JoinType;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.plan.TableFilter;
import io.confluent.ksql.execution.plan.TableGroupBy;
import io.confluent.ksql.execution.plan.TableMapValues;
import io.confluent.ksql.execution.streams.ExecutionStepFactory;
import io.confluent.ksql.execution.streams.StreamsUtil;
import io.confluent.ksql.execution.streams.TableFilterBuilder;
import io.confluent.ksql.execution.streams.TableGroupByBuilder;
import io.confluent.ksql.execution.streams.TableMapValuesBuilder;
import io.confluent.ksql.execution.util.StructKeyUtil;
import io.confluent.ksql.function.FunctionRegistry;
@@ -51,9 +52,6 @@
import java.util.Set;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.KGroupedTable;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Produced;
@@ -231,48 +229,37 @@ public ExecutionStep<KTable<K, GenericRow>> getSourceTableStep() {
}

@Override
@SuppressWarnings("unchecked")
public SchemaKGroupedStream groupBy(
final ValueFormat valueFormat,
final Serde<GenericRow> valSerde,
final List<Expression> groupByExpressions,
final QueryContext.Stacker contextStacker
final QueryContext.Stacker contextStacker,
final KsqlQueryBuilder queryBuilder
) {

final GroupBy groupBy = new GroupBy(groupByExpressions);
final KeyFormat groupedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo());

final KeySerde<Struct> groupedKeySerde = keySerde
.rebind(StructKeyUtil.ROWKEY_SERIALIZED_SCHEMA);

final Grouped<Struct, GenericRow> grouped = streamsFactories.getGroupedFactory()
.create(
StreamsUtil.buildOpName(contextStacker.getQueryContext()),
groupedKeySerde,
valSerde
);

final KGroupedTable kgroupedTable = ktable
.filter((key, value) -> value != null)
.groupBy(
(key, value) -> new KeyValue<>(groupBy.mapper.apply(key, value), value),
grouped
);
final String aggregateKeyName = groupedKeyNameFor(groupByExpressions);
final LegacyField legacyKeyField = LegacyField.notInSchema(aggregateKeyName, SqlTypes.STRING);
final Optional<String> newKeyField =
getSchema().findValueColumn(aggregateKeyName).map(Column::fullName);

final LegacyField legacyKeyField = LegacyField
.notInSchema(groupBy.aggregateKeyName, SqlTypes.STRING);

final Optional<String> newKeyField = getSchema().findValueColumn(groupBy.aggregateKeyName)
.map(Column::fullName);

final ExecutionStep<KGroupedTable<Struct, GenericRow>> step =
ExecutionStepFactory.tableGroupBy(
contextStacker,
sourceTableStep,
Formats.of(groupedKeyFormat, valueFormat, SerdeOption.none()),
groupByExpressions
);
final TableGroupBy<K> step = ExecutionStepFactory.tableGroupBy(
contextStacker,
sourceTableStep,
Formats.of(groupedKeyFormat, valueFormat, SerdeOption.none()),
groupByExpressions
);
return new SchemaKGroupedTable(
kgroupedTable,
TableGroupByBuilder.build(
ktable,
step,
queryBuilder,
streamsFactories.getGroupedFactory()
),
step,
groupedKeyFormat,
groupedKeySerde,
@@ -52,6 +52,7 @@
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.PersistenceSchema;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.serde.FormatInfo;
import io.confluent.ksql.serde.KeySerde;
import io.confluent.ksql.serde.WindowInfo;
import io.confluent.ksql.structured.SchemaKStream;
@@ -35,6 +35,7 @@
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.serde.FormatInfo;
import io.confluent.ksql.serde.KeySerde;
import io.confluent.ksql.structured.SchemaKStream;
import io.confluent.ksql.testutils.AnalysisTestUtil;
@@ -22,6 +22,7 @@

import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.execution.streams.GroupedFactory;
import io.confluent.ksql.util.KsqlConfig;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.streams.StreamsConfig;