From f0fead2d083c2a3a3b36df252c400623f3e280f9 Mon Sep 17 00:00:00 2001 From: rodesai Date: Fri, 20 Sep 2019 01:59:30 -0700 Subject: [PATCH] feat: move aggregations to plan builder This patch moves the code for building aggregations into a plan builder, which requires a few refactorings throughout the codebase: I've moved the windowing POJOs into the ksql-execution module, and split them off from the AST. This way we can serialize these into the aggregation plan nodes to express windows. To express windowed aggregates, I've added a new execution step type called StreamWindowedAggregate. Adding a new type ensures that when we implement the visitor that builds the streams app, the handler for windowed aggregations is type-safe. This patch also includes a refactor of AggregateNode to pass down the aggregation function call expressions rather than the resolved aggregation functions. The code for resolving the function call expressions against the internal schema and building the aggregators, initializers, and undo aggregators has been moved into a class called AggregateParams. The rest of the patch implements the actual aggregation from step builders. 
--- .../io/confluent/ksql/analyzer/Analyzer.java | 2 +- .../GeneratedTableAggregateFunction.java | 1 + .../ksql/function/udaf/count/CountKudaf.java | 2 +- .../function/udaf/sum/DecimalSumKudaf.java | 2 +- .../function/udaf/sum/DoubleSumKudaf.java | 2 +- .../function/udaf/sum/IntegerSumKudaf.java | 2 +- .../ksql/function/udaf/sum/LongSumKudaf.java | 2 +- .../function/udaf/window/WindowEndKudaf.java | 3 +- .../udaf/window/WindowStartKudaf.java | 3 +- .../ksql/planner/plan/AggregateNode.java | 130 +--- .../ksql/structured/SchemaKGroupedStream.java | 140 ++-- .../ksql/structured/SchemaKGroupedTable.java | 71 +- .../function/KudafUndoAggregatorTest.java | 3 +- .../ksql/function/UdfCompilerTest.java | 1 + .../function/udaf/sum/BaseSumKudafTest.java | 2 +- .../udaf/window/WindowSelectMapperTest.java | 1 + .../structured/SchemaKGroupedStreamTest.java | 452 +++---------- .../structured/SchemaKGroupedTableTest.java | 327 +++------ .../function/TableAggregationFunction.java | 4 +- .../ksql/execution/function/UdafUtil.java | 61 ++ .../function/udaf/KudafAggregator.java | 10 +- .../function/udaf/KudafInitializer.java | 12 +- .../function/udaf/KudafUndoAggregator.java | 12 +- .../udaf/window/WindowSelectMapper.java | 11 +- .../ksql/execution/plan/StreamAggregate.java | 38 +- .../plan/StreamWindowedAggregate.java | 114 ++++ .../ksql/execution/plan/TableAggregate.java | 36 +- .../windows}/HoppingWindowExpression.java | 45 +- .../windows/KsqlWindowExpression.java | 34 + .../windows}/SessionWindowExpression.java | 35 +- .../windows}/TumblingWindowExpression.java | 35 +- .../ksql/execution/windows/WindowVisitor.java | 24 + .../ksql/execution/function/UdafUtilTest.java | 109 +++ .../io/confluent/ksql/parser/AstBuilder.java | 6 +- .../parser/rewrite/StatementRewriter.java | 11 +- .../ksql/parser/tree/AstVisitor.java | 16 - .../parser/tree/KsqlWindowExpression.java | 48 -- .../ksql/parser/tree/WindowExpression.java | 1 + .../parser/rewrite/StatementRewriterTest.java | 7 +- 
.../tree/HoppingWindowExpressionTest.java | 36 +- .../ksql/parser/tree/ParserModelTest.java | 2 + .../tree/SessionWindowExpressionTest.java | 29 +- .../tree/TumblingWindowExpressionTest.java | 50 +- .../parser/tree/WindowExpressionTest.java | 1 + .../streams/AggregateBuilderUtils.java | 57 ++ .../execution/streams/AggregateParams.java | 96 +++ .../streams/ExecutionStepFactory.java | 66 +- .../streams/StreamAggregateBuilder.java | 227 +++++++ .../streams/TableAggregateBuilder.java | 76 +++ .../streams/AggregateParamsTest.java | 176 +++++ .../streams/StreamAggregateBuilderTest.java | 619 ++++++++++++++++++ .../streams/TableAggregateBuilderTest.java | 261 ++++++++ 52 files changed, 2317 insertions(+), 1194 deletions(-) rename {ksql-engine/src/main/java/io/confluent/ksql => ksql-execution/src/main/java/io/confluent/ksql/execution}/function/TableAggregationFunction.java (87%) create mode 100644 ksql-execution/src/main/java/io/confluent/ksql/execution/function/UdafUtil.java rename {ksql-engine/src/main/java/io/confluent/ksql => ksql-execution/src/main/java/io/confluent/ksql/execution}/function/udaf/KudafAggregator.java (94%) rename {ksql-engine/src/main/java/io/confluent/ksql => ksql-execution/src/main/java/io/confluent/ksql/execution}/function/udaf/KudafInitializer.java (77%) rename {ksql-engine/src/main/java/io/confluent/ksql => ksql-execution/src/main/java/io/confluent/ksql/execution}/function/udaf/KudafUndoAggregator.java (86%) rename {ksql-engine/src/main/java/io/confluent/ksql => ksql-execution/src/main/java/io/confluent/ksql/execution}/function/udaf/window/WindowSelectMapper.java (86%) create mode 100644 ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamWindowedAggregate.java rename {ksql-parser/src/main/java/io/confluent/ksql/parser/tree => ksql-execution/src/main/java/io/confluent/ksql/execution/windows}/HoppingWindowExpression.java (72%) create mode 100644 
ksql-execution/src/main/java/io/confluent/ksql/execution/windows/KsqlWindowExpression.java rename {ksql-parser/src/main/java/io/confluent/ksql/parser/tree => ksql-execution/src/main/java/io/confluent/ksql/execution/windows}/SessionWindowExpression.java (66%) rename {ksql-parser/src/main/java/io/confluent/ksql/parser/tree => ksql-execution/src/main/java/io/confluent/ksql/execution/windows}/TumblingWindowExpression.java (70%) create mode 100644 ksql-execution/src/main/java/io/confluent/ksql/execution/windows/WindowVisitor.java create mode 100644 ksql-execution/src/test/java/io/confluent/ksql/execution/function/UdafUtilTest.java delete mode 100644 ksql-parser/src/main/java/io/confluent/ksql/parser/tree/KsqlWindowExpression.java create mode 100644 ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateBuilderUtils.java create mode 100644 ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateParams.java create mode 100644 ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamAggregateBuilder.java create mode 100644 ksql-streams/src/main/java/io/confluent/ksql/execution/streams/TableAggregateBuilder.java create mode 100644 ksql-streams/src/test/java/io/confluent/ksql/execution/streams/AggregateParamsTest.java create mode 100644 ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamAggregateBuilderTest.java create mode 100644 ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableAggregateBuilderTest.java diff --git a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java index 190c5fddc1c5..adac1e186004 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java @@ -27,6 +27,7 @@ import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.QualifiedName; import 
io.confluent.ksql.execution.expression.tree.QualifiedNameReference; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.metastore.MetaStore; import io.confluent.ksql.metastore.model.DataSource; import io.confluent.ksql.parser.DefaultTraversalVisitor; @@ -38,7 +39,6 @@ import io.confluent.ksql.parser.tree.GroupingElement; import io.confluent.ksql.parser.tree.Join; import io.confluent.ksql.parser.tree.JoinOn; -import io.confluent.ksql.parser.tree.KsqlWindowExpression; import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.parser.tree.Select; import io.confluent.ksql.parser.tree.SelectItem; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/GeneratedTableAggregateFunction.java b/ksql-engine/src/main/java/io/confluent/ksql/function/GeneratedTableAggregateFunction.java index 95f2f20e15c4..0a9b38e3ae0a 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/GeneratedTableAggregateFunction.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/GeneratedTableAggregateFunction.java @@ -15,6 +15,7 @@ package io.confluent.ksql.function; +import io.confluent.ksql.execution.function.TableAggregationFunction; import io.confluent.ksql.function.udaf.TableUdaf; import java.util.List; import java.util.Optional; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/count/CountKudaf.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/count/CountKudaf.java index b93089c5e746..3d1dd8a8713c 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/count/CountKudaf.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/count/CountKudaf.java @@ -15,10 +15,10 @@ package io.confluent.ksql.function.udaf.count; +import io.confluent.ksql.execution.function.TableAggregationFunction; import io.confluent.ksql.function.AggregateFunctionArguments; import io.confluent.ksql.function.BaseAggregateFunction; import 
io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.TableAggregationFunction; import java.util.Collections; import java.util.List; import java.util.function.Function; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/DecimalSumKudaf.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/DecimalSumKudaf.java index 2e0768ee4d75..84f37caa6777 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/DecimalSumKudaf.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/DecimalSumKudaf.java @@ -15,10 +15,10 @@ package io.confluent.ksql.function.udaf.sum; +import io.confluent.ksql.execution.function.TableAggregationFunction; import io.confluent.ksql.function.AggregateFunctionArguments; import io.confluent.ksql.function.BaseAggregateFunction; import io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.TableAggregationFunction; import io.confluent.ksql.util.DecimalUtil; import java.math.BigDecimal; import java.math.MathContext; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/DoubleSumKudaf.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/DoubleSumKudaf.java index 5b806ceb3dd3..261bb774b88a 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/DoubleSumKudaf.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/DoubleSumKudaf.java @@ -15,10 +15,10 @@ package io.confluent.ksql.function.udaf.sum; +import io.confluent.ksql.execution.function.TableAggregationFunction; import io.confluent.ksql.function.AggregateFunctionArguments; import io.confluent.ksql.function.BaseAggregateFunction; import io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.TableAggregationFunction; import java.util.Collections; import java.util.function.Function; import org.apache.kafka.connect.data.Schema; diff --git 
a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/IntegerSumKudaf.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/IntegerSumKudaf.java index d0d37d080c9e..82c972d25f8f 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/IntegerSumKudaf.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/IntegerSumKudaf.java @@ -15,10 +15,10 @@ package io.confluent.ksql.function.udaf.sum; +import io.confluent.ksql.execution.function.TableAggregationFunction; import io.confluent.ksql.function.AggregateFunctionArguments; import io.confluent.ksql.function.BaseAggregateFunction; import io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.TableAggregationFunction; import java.util.Collections; import java.util.function.Function; import org.apache.kafka.connect.data.Schema; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/LongSumKudaf.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/LongSumKudaf.java index a1627a627c63..1a263c2dd67d 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/LongSumKudaf.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/LongSumKudaf.java @@ -15,10 +15,10 @@ package io.confluent.ksql.function.udaf.sum; +import io.confluent.ksql.execution.function.TableAggregationFunction; import io.confluent.ksql.function.AggregateFunctionArguments; import io.confluent.ksql.function.BaseAggregateFunction; import io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.TableAggregationFunction; import java.util.Collections; import java.util.function.Function; import org.apache.kafka.connect.data.Schema; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowEndKudaf.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowEndKudaf.java index b76eccb29226..482abe7d8d8f 100644 --- 
a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowEndKudaf.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowEndKudaf.java @@ -15,6 +15,7 @@ package io.confluent.ksql.function.udaf.window; +import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper; import io.confluent.ksql.function.udaf.TableUdaf; import io.confluent.ksql.function.udaf.UdafDescription; import io.confluent.ksql.function.udaf.UdafFactory; @@ -37,7 +38,7 @@ private WindowEndKudaf() { } static String getFunctionName() { - return "WindowEnd"; + return WindowSelectMapper.WINDOW_END_NAME; } @UdafFactory(description = "Extracts the window end time") diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowStartKudaf.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowStartKudaf.java index 278cb7dfffac..c8765548373b 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowStartKudaf.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowStartKudaf.java @@ -15,6 +15,7 @@ package io.confluent.ksql.function.udaf.window; +import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper; import io.confluent.ksql.function.udaf.TableUdaf; import io.confluent.ksql.function.udaf.UdafDescription; import io.confluent.ksql.function.udaf.UdafFactory; @@ -37,7 +38,7 @@ private WindowStartKudaf() { } static String getFunctionName() { - return "WindowStart"; + return WindowSelectMapper.WINDOW_START_NAME; } @UdafFactory(description = "Extracts the window start time") diff --git a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java index d4f313690e0a..66b226c76c5e 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java @@ 
-19,7 +19,6 @@ import static java.util.Objects.requireNonNull; import com.google.common.collect.ImmutableList; -import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.expression.tree.Expression; @@ -28,12 +27,10 @@ import io.confluent.ksql.execution.expression.tree.QualifiedName; import io.confluent.ksql.execution.expression.tree.QualifiedNameReference; import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor; +import io.confluent.ksql.execution.function.UdafUtil; import io.confluent.ksql.execution.plan.SelectExpression; -import io.confluent.ksql.execution.util.ExpressionTypeManager; -import io.confluent.ksql.function.AggregateFunctionArguments; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.udaf.KudafInitializer; import io.confluent.ksql.materialization.MaterializationInfo; import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.parser.rewrite.ExpressionTreeRewriter; @@ -41,11 +38,9 @@ import io.confluent.ksql.parser.tree.WindowExpression; import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.LogicalSchema; -import io.confluent.ksql.schema.ksql.PhysicalSchema; import io.confluent.ksql.schema.ksql.SchemaConverters; import io.confluent.ksql.schema.ksql.SchemaConverters.ConnectToSqlTypeConverter; import io.confluent.ksql.schema.ksql.types.SqlType; -import io.confluent.ksql.serde.SerdeOption; import io.confluent.ksql.serde.ValueFormat; import io.confluent.ksql.services.KafkaTopicClient; import io.confluent.ksql.structured.SchemaKGroupedStream; @@ -66,8 +61,6 @@ import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; -import org.apache.kafka.common.serialization.Serde; -import org.apache.kafka.connect.data.Schema; public class 
AggregateNode extends PlanNode { @@ -234,54 +227,43 @@ public SchemaKStream buildStream(final KsqlQueryBuilder builder) { builder ); - // Aggregate computations - final KudafInitializer initializer = new KudafInitializer(requiredColumns.size()); - - final Map aggValToFunctionMap = createAggValToFunctionMap( - aggregateArgExpanded, - initializer, - requiredColumns.size(), - builder.getFunctionRegistry(), - internalSchema - ); + final List functionsWithInternalIdentifiers = functionList.stream() + .map( + fc -> new FunctionCall( + fc.getName(), + internalSchema.getInternalArgsExpressionList(fc.getArguments()) + ) + ) + .collect(Collectors.toList()); // This is the schema of the aggregation change log topic and associated state store. // It contains all columns from prepareSchema and columns for any aggregating functions // It uses internal column names, e.g. KSQL_INTERNAL_COL_0 and KSQL_AGG_VARIABLE_0 final LogicalSchema aggregationSchema = buildLogicalSchema( prepareSchema, - aggValToFunctionMap, + functionsWithInternalIdentifiers, + builder.getFunctionRegistry(), true ); final QueryContext.Stacker aggregationContext = contextStacker.push(AGGREGATION_OP_NAME); - final Serde aggValueGenericRowSerde = builder.buildValueSerde( - valueFormat.getFormatInfo(), - PhysicalSchema.from(aggregationSchema, SerdeOption.none()), - aggregationContext.getQueryContext() - ); - - final List functionsWithInternalIdentifiers = functionList.stream() - .map(internalSchema::resolveToInternal) - .map(FunctionCall.class::cast) - .collect(Collectors.toList()); - final LogicalSchema outputSchema = buildLogicalSchema( prepareSchema, - aggValToFunctionMap, - false); + functionsWithInternalIdentifiers, + builder.getFunctionRegistry(), + false + ); SchemaKTable aggregated = schemaKGroupedStream.aggregate( outputSchema, - initializer, + aggregationSchema, requiredColumns.size(), functionsWithInternalIdentifiers, - aggValToFunctionMap, getWindowExpression(), valueFormat, - aggValueGenericRowSerde, 
- aggregationContext + aggregationContext, + builder ); final Optional havingExpression = Optional.ofNullable(havingExpressions) @@ -316,61 +298,12 @@ protected int getPartitions(final KafkaTopicClient kafkaTopicClient) { return source.getPartitions(kafkaTopicClient); } - private Map createAggValToFunctionMap( - final SchemaKStream aggregateArgExpanded, - final KudafInitializer initializer, - final int initialUdafIndex, - final FunctionRegistry functionRegistry, - final InternalSchema internalSchema - ) { - int udafIndexInAggSchema = initialUdafIndex; - final Map aggValToAggFunctionMap = new HashMap<>(); - for (final FunctionCall functionCall : functionList) { - final KsqlAggregateFunction aggregateFunction = getAggregateFunction( - functionRegistry, - internalSchema, - functionCall, aggregateArgExpanded.getSchema()); - - aggValToAggFunctionMap.put(udafIndexInAggSchema++, aggregateFunction); - initializer.addAggregateIntializer(aggregateFunction.getInitialValueSupplier()); - } - return aggValToAggFunctionMap; - } - - @SuppressWarnings("deprecation") // Need to migrate away from Connect Schema use. 
- private static KsqlAggregateFunction getAggregateFunction( - final FunctionRegistry functionRegistry, - final InternalSchema internalSchema, - final FunctionCall functionCall, - final LogicalSchema schema - ) { - try { - final ExpressionTypeManager expressionTypeManager = - new ExpressionTypeManager(schema, functionRegistry); - final List functionArgs = internalSchema.getInternalArgsExpressionList( - functionCall.getArguments()); - final Schema expressionType = expressionTypeManager.getExpressionSchema(functionArgs.get(0)); - final KsqlAggregateFunction aggregateFunctionInfo = functionRegistry - .getAggregate(functionCall.getName().toString(), expressionType); - - final List args = functionArgs.stream() - .map(Expression::toString) - .collect(Collectors.toList()); - - final int udafIndex = Integer - .parseInt(args.get(0).substring(INTERNAL_COLUMN_NAME_PREFIX.length())); - - return aggregateFunctionInfo.getInstance(new AggregateFunctionArguments(udafIndex, args)); - } catch (final Exception e) { - throw new KsqlException("Failed to create aggregate function: " + functionCall, e); - } - } - private LogicalSchema buildLogicalSchema( final LogicalSchema inputSchema, - final Map aggregateFunctions, - final boolean useAggregate) { - + final List aggregations, + final FunctionRegistry functionRegistry, + final boolean useAggregate + ) { final LogicalSchema.Builder schemaBuilder = LogicalSchema.builder(); final List cols = inputSchema.value(); @@ -382,18 +315,13 @@ private LogicalSchema buildLogicalSchema( final ConnectToSqlTypeConverter converter = SchemaConverters.connectToSqlConverter(); - for (int idx = 0; idx < aggregateFunctions.size(); idx++) { - - final KsqlAggregateFunction aggregateFunction = aggregateFunctions - .get(requiredColumns.size() + idx); - - final String colName = AggregateExpressionRewriter.AGGREGATE_FUNCTION_VARIABLE_PREFIX + idx; - SqlType fieldType = null; - if (useAggregate) { - fieldType = 
converter.toSqlType(aggregateFunction.getAggregateType()); - } else { - fieldType = converter.toSqlType(aggregateFunction.getReturnType()); - } + for (int i = 0; i < aggregations.size(); i++) { + final KsqlAggregateFunction aggregateFunction = + UdafUtil.resolveAggregateFunction(functionRegistry, aggregations.get(i), inputSchema); + final String colName = AggregateExpressionRewriter.AGGREGATE_FUNCTION_VARIABLE_PREFIX + i; + final SqlType fieldType = converter.toSqlType( + useAggregate ? aggregateFunction.getAggregateType() : aggregateFunction.getReturnType() + ); schemaBuilder.valueColumn(colName, fieldType); } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedStream.java b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedStream.java index 7b2a99aef347..637d851766ea 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedStream.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedStream.java @@ -16,21 +16,19 @@ package io.confluent.ksql.structured; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.plan.ExecutionStep; import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.StreamAggregate; +import io.confluent.ksql.execution.plan.StreamWindowedAggregate; import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.execution.streams.MaterializedFactory; -import io.confluent.ksql.execution.streams.StreamsUtil; +import io.confluent.ksql.execution.streams.StreamAggregateBuilder; import io.confluent.ksql.function.FunctionRegistry; -import io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.UdafAggregator; -import io.confluent.ksql.function.udaf.KudafAggregator; -import 
io.confluent.ksql.function.udaf.window.WindowSelectMapper; import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.model.WindowType; -import io.confluent.ksql.parser.tree.KsqlWindowExpression; import io.confluent.ksql.parser.tree.WindowExpression; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.serde.Format; @@ -43,15 +41,11 @@ import io.confluent.ksql.util.KsqlConfig; import java.time.Duration; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.Optional; -import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; import org.apache.kafka.streams.kstream.KGroupedStream; import org.apache.kafka.streams.kstream.KTable; -import org.apache.kafka.streams.kstream.Materialized; import org.apache.kafka.streams.kstream.Windowed; public class SchemaKGroupedStream { @@ -122,54 +116,61 @@ public ExecutionStep> getSourceStep() { @SuppressWarnings("unchecked") public SchemaKTable aggregate( final LogicalSchema outputSchema, - final Initializer initializer, + final LogicalSchema aggregateSchema, final int nonFuncColumnCount, final List aggregations, - final Map aggValToFunctionMap, final WindowExpression windowExpression, final ValueFormat valueFormat, - final Serde topicValueSerDe, - final QueryContext.Stacker contextStacker + final QueryContext.Stacker contextStacker, + final KsqlQueryBuilder queryBuilder ) { - throwOnValueFieldCountMismatch(outputSchema, nonFuncColumnCount, aggValToFunctionMap); + throwOnValueFieldCountMismatch(outputSchema, nonFuncColumnCount, aggregations); + final ExecutionStep> step; final KTable table; final KeySerde newKeySerde; final KeyFormat keyFormat; if (windowExpression != null) { keyFormat = getKeyFormat(windowExpression); newKeySerde = getKeySerde(windowExpression); - - table = aggregateWindowed( - initializer, + final StreamWindowedAggregate aggregate = 
ExecutionStepFactory.streamWindowedAggregate( + contextStacker, + sourceStep, + outputSchema, + Formats.of(keyFormat, valueFormat, SerdeOption.none()), nonFuncColumnCount, - aggValToFunctionMap, - windowExpression, - topicValueSerDe, - contextStacker + aggregations, + aggregateSchema, + windowExpression.getKsqlWindowExpression() + ); + step = aggregate; + table = StreamAggregateBuilder.build( + kgroupedStream, + aggregate, + queryBuilder, + materializedFactory ); } else { keyFormat = this.keyFormat; newKeySerde = keySerde; - - table = aggregateNonWindowed( - initializer, + final StreamAggregate aggregate = ExecutionStepFactory.streamAggregate( + contextStacker, + sourceStep, + outputSchema, + Formats.of(keyFormat, valueFormat, SerdeOption.none()), nonFuncColumnCount, - aggValToFunctionMap, - topicValueSerDe, - contextStacker + aggregations, + aggregateSchema + ); + step = aggregate; + table = StreamAggregateBuilder.build( + kgroupedStream, + aggregate, + queryBuilder, + materializedFactory ); } - final ExecutionStep step = ExecutionStepFactory.streamAggregate( - contextStacker, - sourceStep, - outputSchema, - Formats.of(keyFormat, valueFormat, SerdeOption.none()), - nonFuncColumnCount, - aggregations - ); - return new SchemaKTable( table, step, @@ -183,61 +184,6 @@ public SchemaKTable aggregate( ); } - @SuppressWarnings("unchecked") - private KTable aggregateNonWindowed( - final Initializer initializer, - final int nonFuncColumnCount, - final Map indexToFunctionMap, - final Serde topicValueSerDe, - final QueryContext.Stacker contextStacker - ) { - final UdafAggregator aggregator = new KudafAggregator(nonFuncColumnCount, indexToFunctionMap); - - final Materialized materialized = materializedFactory.create( - keySerde, - topicValueSerDe, - StreamsUtil.buildOpName(contextStacker.getQueryContext()) - ); - - final KTable aggTable = kgroupedStream.aggregate(initializer, aggregator, materialized); - - return getAggregationResult(aggTable, aggregator); - } - - 
@SuppressWarnings("unchecked") - private KTable aggregateWindowed( - final Initializer initializer, - final int nonFuncColumnCount, - final Map indexToFunctionMap, - final WindowExpression windowExpression, - final Serde topicValueSerDe, - final QueryContext.Stacker contextStacker - ) { - final UdafAggregator aggregator = new KudafAggregator(nonFuncColumnCount, indexToFunctionMap); - - final KsqlWindowExpression ksqlWindowExpression = windowExpression.getKsqlWindowExpression(); - - final Materialized materialized = materializedFactory.create( - keySerde, - topicValueSerDe, - StreamsUtil.buildOpName(contextStacker.getQueryContext()) - ); - - final KTable, GenericRow> aggKtable = ksqlWindowExpression.applyAggregate( - kgroupedStream, initializer, aggregator, materialized); - - // Apply the mapper before window_start and window_end functions that return null if a - // record is not part of the window. - final KTable reducedTable = getAggregationResult(aggKtable, aggregator); - - final WindowSelectMapper windowSelectMapper = new WindowSelectMapper(indexToFunctionMap); - if (!windowSelectMapper.hasSelects()) { - return reducedTable; - } - - return reducedTable.mapValues(windowSelectMapper); - } - private KeyFormat getKeyFormat(final WindowExpression windowExpression) { if (ksqlConfig.getBoolean(KsqlConfig.KSQL_WINDOWED_SESSION_KEY_LEGACY_CONFIG)) { return KeyFormat.windowed( @@ -254,11 +200,6 @@ private KeyFormat getKeyFormat(final WindowExpression windowExpression) { ); } - @SuppressWarnings("unchecked") - private KTable getAggregationResult(final KTable table, final UdafAggregator aggregator) { - return table.mapValues(aggregator.getResultMapper()); - } - private KeySerde> getKeySerde(final WindowExpression windowExpression) { if (ksqlConfig.getBoolean(KsqlConfig.KSQL_WINDOWED_SESSION_KEY_LEGACY_CONFIG)) { return keySerde.rebind(WindowInfo.of( @@ -273,10 +214,9 @@ private KeySerde> getKeySerde(final WindowExpression windowExpr static void 
throwOnValueFieldCountMismatch( final LogicalSchema aggregateSchema, final int nonFuncColumnCount, - final Map aggValToFunctionMap + final List aggregateFunctions ) { - final int nonAggColumnCount = aggValToFunctionMap.size(); - final int totalColumnCount = nonAggColumnCount + nonFuncColumnCount; + final int totalColumnCount = aggregateFunctions.size() + nonFuncColumnCount; final int valueColumnCount = aggregateSchema.value().size(); if (valueColumnCount != totalColumnCount) { diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedTable.java b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedTable.java index 1f9976f31074..4c5b0928ea04 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedTable.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedTable.java @@ -16,18 +16,19 @@ package io.confluent.ksql.structured; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.function.TableAggregationFunction; +import io.confluent.ksql.execution.function.UdafUtil; import io.confluent.ksql.execution.plan.ExecutionStep; import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.TableAggregate; import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.execution.streams.MaterializedFactory; -import io.confluent.ksql.execution.streams.StreamsUtil; +import io.confluent.ksql.execution.streams.TableAggregateBuilder; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.TableAggregationFunction; -import io.confluent.ksql.function.udaf.KudafAggregator; -import io.confluent.ksql.function.udaf.KudafUndoAggregator; import 
io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.parser.tree.WindowExpression; import io.confluent.ksql.schema.ksql.LogicalSchema; @@ -38,15 +39,10 @@ import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.KsqlException; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; -import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; import org.apache.kafka.streams.kstream.KGroupedTable; -import org.apache.kafka.streams.kstream.KTable; -import org.apache.kafka.streams.kstream.Materialized; public class SchemaKGroupedTable extends SchemaKGroupedStream { private final KGroupedTable kgroupedTable; @@ -109,24 +105,24 @@ public ExecutionStep> getSourceTableStep() { @Override public SchemaKTable aggregate( final LogicalSchema outputSchema, - final Initializer initializer, + final LogicalSchema aggregateSchema, final int nonFuncColumnCount, final List aggregations, - final Map aggValToFunctionMap, final WindowExpression windowExpression, final ValueFormat valueFormat, - final Serde topicValueSerDe, - final QueryContext.Stacker contextStacker + final QueryContext.Stacker contextStacker, + final KsqlQueryBuilder queryBuilder ) { if (windowExpression != null) { throw new KsqlException("Windowing not supported for table aggregations."); } - throwOnValueFieldCountMismatch(outputSchema, nonFuncColumnCount, aggValToFunctionMap); + throwOnValueFieldCountMismatch(outputSchema, nonFuncColumnCount, aggregations); - final List unsupportedFunctionNames = aggValToFunctionMap.values() - .stream() - .filter(function -> !(function instanceof TableAggregationFunction)) + final List unsupportedFunctionNames = aggregations.stream() + .map(call -> UdafUtil.resolveAggregateFunction( + queryBuilder.getFunctionRegistry(), call, sourceTableStep.getSchema()) + ).filter(function -> !(function instanceof 
TableAggregationFunction)) .map(KsqlAggregateFunction::getFunctionName) .collect(Collectors.toList()); if (!unsupportedFunctionNames.isEmpty()) { @@ -136,46 +132,23 @@ public SchemaKTable aggregate( String.join(", ", unsupportedFunctionNames))); } - final KudafAggregator aggregator = new KudafAggregator( - nonFuncColumnCount, aggValToFunctionMap); - - final Map aggValToUndoFunctionMap = - aggValToFunctionMap.keySet() - .stream() - .collect( - Collectors.toMap( - k -> k, - k -> ((TableAggregationFunction) aggValToFunctionMap.get(k)))); - - final KudafUndoAggregator subtractor = new KudafUndoAggregator( - nonFuncColumnCount, aggValToUndoFunctionMap); - - final Materialized materialized = materializedFactory.create( - keySerde, - topicValueSerDe, - StreamsUtil.buildOpName(contextStacker.getQueryContext()) - ); - - final KTable aggKtable = kgroupedTable.aggregate( - initializer, - aggregator, - subtractor, - materialized); - - final ExecutionStep step = ExecutionStepFactory.tableAggregate( + final TableAggregate step = ExecutionStepFactory.tableAggregate( contextStacker, sourceTableStep, outputSchema, Formats.of(keyFormat, valueFormat, SerdeOption.none()), nonFuncColumnCount, - aggregations + aggregations, + aggregateSchema ); - final KTable outputTable = aggKtable.mapValues( - aggregator.getResultMapper()); - return new SchemaKTable<>( - outputTable, + TableAggregateBuilder.build( + kgroupedTable, + step, + queryBuilder, + materializedFactory + ), step, keyFormat, keySerde, diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/KudafUndoAggregatorTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/KudafUndoAggregatorTest.java index 5433c574f11d..96a4dd5008d2 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/KudafUndoAggregatorTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/KudafUndoAggregatorTest.java @@ -20,7 +20,8 @@ import static org.junit.Assert.assertThat; import io.confluent.ksql.GenericRow; -import 
io.confluent.ksql.function.udaf.KudafUndoAggregator; +import io.confluent.ksql.execution.function.TableAggregationFunction; +import io.confluent.ksql.execution.function.udaf.KudafUndoAggregator; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/UdfCompilerTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/UdfCompilerTest.java index 2fdc8737c991..22a82a39e3f4 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/UdfCompilerTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/UdfCompilerTest.java @@ -23,6 +23,7 @@ import com.google.common.collect.ImmutableList; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import io.confluent.ksql.execution.function.TableAggregationFunction; import io.confluent.ksql.function.udaf.TestUdaf; import io.confluent.ksql.function.udaf.Udaf; import io.confluent.ksql.util.KsqlException; diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/udaf/sum/BaseSumKudafTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/udaf/sum/BaseSumKudafTest.java index cc8373211434..aef5c8ff2810 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/udaf/sum/BaseSumKudafTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/udaf/sum/BaseSumKudafTest.java @@ -18,7 +18,7 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.junit.Assert.assertThat; -import io.confluent.ksql.function.TableAggregationFunction; +import io.confluent.ksql.execution.function.TableAggregationFunction; import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/udaf/window/WindowSelectMapperTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/udaf/window/WindowSelectMapperTest.java index dba6ef59cf22..aeea8ba6dd97 100644 --- 
a/ksql-engine/src/test/java/io/confluent/ksql/function/udaf/window/WindowSelectMapperTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/udaf/window/WindowSelectMapperTest.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper; import io.confluent.ksql.function.KsqlAggregateFunction; import java.util.ArrayList; import java.util.Arrays; diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedStreamTest.java b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedStreamTest.java index 4d593585a7c9..49be0faa5888 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedStreamTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedStreamTest.java @@ -15,40 +15,33 @@ package io.confluent.ksql.structured; -import static java.util.Collections.emptyList; -import static java.util.Collections.emptyMap; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.sameInstance; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.expression.tree.FunctionCall; +import 
io.confluent.ksql.execution.expression.tree.QualifiedName; +import io.confluent.ksql.execution.expression.tree.QualifiedNameReference; import io.confluent.ksql.execution.plan.ExecutionStep; import io.confluent.ksql.execution.plan.Formats; import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.execution.streams.MaterializedFactory; -import io.confluent.ksql.execution.streams.StreamsUtil; +import io.confluent.ksql.execution.windows.SessionWindowExpression; import io.confluent.ksql.function.FunctionRegistry; -import io.confluent.ksql.function.KsqlAggregateFunction; +import io.confluent.ksql.function.InternalFunctionRegistry; import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.model.WindowType; -import io.confluent.ksql.parser.tree.KsqlWindowExpression; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.parser.tree.WindowExpression; import io.confluent.ksql.query.QueryId; -import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.serde.Format; import io.confluent.ksql.serde.FormatInfo; import io.confluent.ksql.serde.KeyFormat; @@ -57,21 +50,17 @@ import io.confluent.ksql.serde.ValueFormat; import io.confluent.ksql.serde.WindowInfo; import io.confluent.ksql.util.KsqlConfig; -import java.time.Duration; -import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Optional; -import java.util.stream.Collectors; -import java.util.stream.IntStream; +import java.util.concurrent.TimeUnit; import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; import org.apache.kafka.streams.kstream.KGroupedStream; import org.apache.kafka.streams.kstream.KTable; import org.apache.kafka.streams.kstream.Materialized; +import 
org.apache.kafka.streams.kstream.SessionWindowedKStream; +import org.apache.kafka.streams.kstream.SessionWindows; import org.apache.kafka.streams.kstream.ValueMapper; -import org.apache.kafka.streams.kstream.ValueMapperWithKey; import org.apache.kafka.streams.kstream.Windowed; import org.junit.Before; import org.junit.Test; @@ -82,39 +71,45 @@ @SuppressWarnings("unchecked") @RunWith(MockitoJUnitRunner.class) public class SchemaKGroupedStreamTest { - @Mock - private LogicalSchema aggregateSchema; + private static final LogicalSchema IN_SCHEMA = LogicalSchema.builder() + .valueColumn("IN0", SqlTypes.STRING) + .valueColumn("IN1", SqlTypes.INTEGER) + .build(); + private static final LogicalSchema AGG_SCHEMA = LogicalSchema.builder() + .valueColumn("IN0", SqlTypes.STRING) + .valueColumn("AGG0", SqlTypes.BIGINT) + .build(); + private static final LogicalSchema OUT_SCHEMA = LogicalSchema.builder() + .valueColumn("IN0", SqlTypes.STRING) + .valueColumn("OUT0", SqlTypes.STRING) + .build(); + private static final FunctionCall AGG = new FunctionCall( + QualifiedName.of("SUM"), + ImmutableList.of(new QualifiedNameReference(QualifiedName.of("IN1"))) + ); + private static final KsqlWindowExpression KSQL_WINDOW_EXP = new SessionWindowExpression( + 100, TimeUnit.SECONDS + ); + @Mock private KGroupedStream groupedStream; @Mock + private SessionWindowedKStream sessionWindowedStream; + @Mock private KeyField keyField; @Mock private List sourceStreams; @Mock private KsqlConfig config; @Mock - private FunctionRegistry funcRegistry; - @Mock - private Initializer initializer; - @Mock private Serde topicValueSerDe; @Mock - private KsqlAggregateFunction windowStartFunc; - @Mock - private KsqlAggregateFunction windowEndFunc; - @Mock - private KsqlAggregateFunction otherFunc; - @Mock private FunctionCall aggCall; @Mock private KTable table; @Mock - private KTable table2; - @Mock private WindowExpression windowExp; @Mock - private KsqlWindowExpression ksqlWindowExp; - @Mock private 
MaterializedFactory materializedFactory; @Mock private Materialized materialized; @@ -123,15 +118,18 @@ public class SchemaKGroupedStreamTest { @Mock private KeySerde> windowedKeySerde; @Mock - private Column field; - @Mock private ExecutionStep sourceStep; @Mock private KeyFormat keyFormat; @Mock private ValueFormat valueFormat; + @Mock + private KsqlQueryBuilder builder; + + private final FunctionRegistry functionRegistry = new InternalFunctionRegistry(); private final QueryContext.Stacker queryContext = new QueryContext.Stacker(new QueryId("query")).push("node"); + private SchemaKGroupedStream schemaGroupedStream; @Before @@ -144,328 +142,56 @@ public void setUp() { keyField, sourceStreams, config, - funcRegistry, + functionRegistry, materializedFactory ); - - when(windowStartFunc.getFunctionName()).thenReturn("WindowStart"); - when(windowEndFunc.getFunctionName()).thenReturn("WindowEnd"); - when(otherFunc.getFunctionName()).thenReturn("NotWindowStartFunc"); - when(windowExp.getKsqlWindowExpression()).thenReturn(ksqlWindowExp); + when(sourceStep.getSchema()).thenReturn(IN_SCHEMA); + when(windowExp.getKsqlWindowExpression()).thenReturn(KSQL_WINDOW_EXP); when(config.getBoolean(KsqlConfig.KSQL_WINDOWED_SESSION_KEY_LEGACY_CONFIG)).thenReturn(false); when(materializedFactory.create(any(), any(), any())).thenReturn(materialized); - - when(ksqlWindowExp.getWindowInfo()) - .thenReturn(WindowInfo.of(WindowType.SESSION, Optional.empty())); - + when(builder.buildKeySerde(any(), any(), any())).thenReturn(keySerde); + when(builder.buildValueSerde(any(), any(), any())).thenReturn(topicValueSerDe); + when(builder.getFunctionRegistry()).thenReturn(functionRegistry); when(keySerde.rebind(any(WindowInfo.class))).thenReturn(windowedKeySerde); - - when(aggregateSchema.value()).thenReturn(mock(List.class)); - - when(ksqlWindowExp.applyAggregate(any(), any(), any(), any())).thenReturn(table); when(table.mapValues(any(ValueMapper.class))).thenReturn(table); } @Test - public void 
shouldNoUseSelectMapperForNonWindowed() { + public void shouldReturnKTableWithOutputSchema() { // Given: - final Map invalidWindowFuncs = ImmutableMap.of( - 2, windowStartFunc, 4, windowEndFunc); + when(groupedStream.aggregate(any(), any(), any())).thenReturn(table); - // When: - assertDoesNotInstallWindowSelectMapper(null, invalidWindowFuncs); - } - - @Test - public void shouldNotUseSelectMapperForWindowedWithoutWindowSelects() { - // Given: - final Map nonWindowFuncs = ImmutableMap.of(2, otherFunc); - - // When: - assertDoesNotInstallWindowSelectMapper(windowExp, nonWindowFuncs); - } - - @Test - public void shouldUseSelectMapperForWindowedWithWindowStart() { - // Given: - Map funcMapWithWindowStart = ImmutableMap.of( - 0, otherFunc, 1, windowStartFunc); - - // Then: - assertDoesInstallWindowSelectMapper(funcMapWithWindowStart); - } - - @Test - public void shouldUseSelectMapperForWindowedWithWindowEnd() { - // Given: - Map funcMapWithWindowEnd = ImmutableMap.of( - 0, windowEndFunc, 1, otherFunc); - - // Then: - assertDoesInstallWindowSelectMapper(funcMapWithWindowEnd); - } - - @Test - public void shouldSupportSessionWindowedKey() { - // Given: - final WindowInfo windowInfo = WindowInfo.of(WindowType.SESSION, Optional.empty()); - when(ksqlWindowExp.getWindowInfo()).thenReturn(windowInfo); - - // When: - final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - emptyMap(), - windowExp, - valueFormat, - topicValueSerDe, - queryContext - ); - - // Then: - verify(keySerde).rebind(windowInfo); - assertThat(result.getKeySerde(), is(windowedKeySerde)); - } - - @Test - public void shouldSupportHoppingWindowedKey() { - // Given: - final WindowInfo windowInfo = WindowInfo - .of(WindowType.HOPPING, Optional.of(Duration.ofMillis(10))); - when(ksqlWindowExp.getWindowInfo()).thenReturn(windowInfo); - - // When: - final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, - 0, - 
emptyList(), - emptyMap(), - windowExp, - valueFormat, - topicValueSerDe, - queryContext - ); - - // Then: - verify(keySerde).rebind(windowInfo); - assertThat(result.getKeySerde(), is(windowedKeySerde)); - } - - @Test - public void shouldSupportTumblingWindowedKey() { - // Given: - final WindowInfo windowInfo = WindowInfo - .of(WindowType.TUMBLING, Optional.of(Duration.ofMillis(10))); - - when(ksqlWindowExp.getWindowInfo()).thenReturn(windowInfo); // When: final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - emptyMap(), - windowExp, - valueFormat, - topicValueSerDe, - queryContext - ); - - // Then: - verify(keySerde).rebind(windowInfo); - assertThat(result.getKeySerde(), is(windowedKeySerde)); - } - - @Test - public void shouldUseTimeWindowKeySerdeForWindowedIfLegacyConfig() { - // Given: - when(config.getBoolean(KsqlConfig.KSQL_WINDOWED_SESSION_KEY_LEGACY_CONFIG)) - .thenReturn(true); - - // When: - final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - emptyMap(), - windowExp, - valueFormat, - topicValueSerDe, - queryContext - ); - - // Then: - verify(keySerde) - .rebind(WindowInfo.of(WindowType.TUMBLING, Optional.of(Duration.ofMillis(Long.MAX_VALUE)))); - assertThat(result.getKeySerde(), is(windowedKeySerde)); - } - - private void assertDoesNotInstallWindowSelectMapper( - final WindowExpression windowExp, - final Map funcMap) { - - // Given: - if (windowExp != null) { - when(ksqlWindowExp.applyAggregate(any(), any(), any(), any())) - .thenReturn(table); - } else { - when(groupedStream.aggregate(any(), any(), any())) - .thenReturn(table); - } - givenAggregateSchemaFieldCount(funcMap.size()); - - // When: - final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - funcMap, - windowExp, - valueFormat, - topicValueSerDe, - queryContext - ); - - // Then: - assertThat(result.getKtable(), 
is(sameInstance(table))); - verify(table, never()).mapValues(any(ValueMapperWithKey.class)); - } - - private void assertDoesInstallWindowSelectMapper( - final Map funcMap) { - - // Given: - when(table.mapValues(any(ValueMapperWithKey.class))).thenReturn(table2); - givenAggregateSchemaFieldCount(funcMap.size()); - - // When: - final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - funcMap, - windowExp, - valueFormat, - topicValueSerDe, - queryContext - ); - - // Then: - assertThat(result.getKtable(), is(sameInstance(table2))); - verify(table, times(1)).mapValues(any(ValueMapperWithKey.class)); - } - - @SuppressWarnings("unchecked") - private Materialized whenMaterializedFactoryCreates() { - final Materialized materialized = mock(Materialized.class); - when(materializedFactory.create(any(), any(), any())).thenReturn(materialized); - return materialized; - } - - @SuppressWarnings("unchecked") - @Test - public void shouldUseMaterializedFactoryForStateStore() { - // Given: - final Materialized materialized = whenMaterializedFactoryCreates(); - final KTable mockKTable = mock(KTable.class); - when(groupedStream.aggregate(any(), any(), same(materialized))).thenReturn(mockKTable); - - // When: - schemaGroupedStream.aggregate( - aggregateSchema, - () -> null, - 0, - emptyList(), - Collections.emptyMap(), + OUT_SCHEMA, + AGG_SCHEMA, + 1, + ImmutableList.of(AGG), null, valueFormat, - topicValueSerDe, - queryContext - ); - - // Then: - verify(materializedFactory) - .create( - same(keySerde), - same(topicValueSerDe), - eq(StreamsUtil.buildOpName(queryContext.getQueryContext()))); - verify(groupedStream, times(1)).aggregate(any(), any(), same(materialized)); - } - - @SuppressWarnings("unchecked") - @Test - public void shouldUseMaterializedFactoryWindowedStateStore() { - // Given: - final Materialized materialized = whenMaterializedFactoryCreates(); - when(ksqlWindowExp.applyAggregate(any(), any(), any(), 
same(materialized))) - .thenReturn(table); - - // When: - schemaGroupedStream.aggregate( - aggregateSchema, - () -> null, - 0, - emptyList(), - Collections.emptyMap(), - windowExp, - valueFormat, - topicValueSerDe, - queryContext); - - // Then: - verify(materializedFactory) - .create( - same(keySerde), - same(topicValueSerDe), - eq(StreamsUtil.buildOpName(queryContext.getQueryContext()))); - verify(ksqlWindowExp, times(1)).applyAggregate(any(), any(), any(), same(materialized)); - } - - @Test - public void shouldReturnKTableWithAggregateSchema() { - // When: - final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - emptyMap(), - windowExp, - valueFormat, - topicValueSerDe, - queryContext + queryContext, + builder ); // Then: - assertThat(result.getSchema(), is(aggregateSchema)); + assertThat(result.getSchema(), is(OUT_SCHEMA)); } @Test public void shouldBuildStepForAggregate() { // Given: - final Map functions = ImmutableMap.of(1, otherFunc); - when(aggregateSchema.value()) - .thenReturn(ImmutableList.of(mock(Column.class), mock(Column.class))); - when(groupedStream.aggregate(any(), any(), any())) - .thenReturn(table); + when(groupedStream.aggregate(any(), any(), any())).thenReturn(table); // When: final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, + OUT_SCHEMA, + AGG_SCHEMA, 1, - ImmutableList.of(aggCall), - functions, + ImmutableList.of(AGG), null, valueFormat, - topicValueSerDe, - queryContext + queryContext, + builder ); // Then: @@ -475,28 +201,34 @@ public void shouldBuildStepForAggregate() { ExecutionStepFactory.streamAggregate( queryContext, schemaGroupedStream.getSourceStep(), - aggregateSchema, + OUT_SCHEMA, Formats.of(keyFormat, valueFormat, SerdeOption.none()), 1, - ImmutableList.of(aggCall) + ImmutableList.of(AGG), + AGG_SCHEMA ) ) ); + assertThat(result.getKtable(), is(table)); } @Test - public void shouldBuildStepKeyFormatForWindowedAggregate() { + 
public void shouldBuildStepForWindowedAggregate() { + // Given: + when(groupedStream.windowedBy(any(SessionWindows.class))).thenReturn(sessionWindowedStream); + when(sessionWindowedStream.aggregate(any(), any(), any(), any())).thenReturn(table); + when(table.mapValues(any(ValueMapper.class))).thenReturn(table); + // When: final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - Collections.emptyMap(), + OUT_SCHEMA, + AGG_SCHEMA, + 1, + ImmutableList.of(AGG), windowExp, valueFormat, - topicValueSerDe, - queryContext + queryContext, + builder ); // Then: @@ -507,47 +239,33 @@ public void shouldBuildStepKeyFormatForWindowedAggregate() { assertThat( result.getSourceTableStep(), equalTo( - ExecutionStepFactory.streamAggregate( + ExecutionStepFactory.streamWindowedAggregate( queryContext, schemaGroupedStream.getSourceStep(), - aggregateSchema, + OUT_SCHEMA, Formats.of(expected, valueFormat, SerdeOption.none()), - 0, - Collections.emptyList() + 1, + ImmutableList.of(AGG), + AGG_SCHEMA, + KSQL_WINDOW_EXP ) ) ); + assertThat(result.getKtable(), is(table)); } @Test(expected = IllegalArgumentException.class) public void shouldThrowOnColumnCountMismatch() { - // Given: - // Agg schema has 2 fields: - givenAggregateSchemaFieldCount(2); - - // Where as params have 1 nonAgg and 2 agg fields: - final Map aggColumns = ImmutableMap.of(2, otherFunc); - // When: schemaGroupedStream.aggregate( - aggregateSchema, - initializer, + OUT_SCHEMA, + AGG_SCHEMA, 2, ImmutableList.of(aggCall), - aggColumns, windowExp, valueFormat, - topicValueSerDe, - queryContext + queryContext, + builder ); } - - private void givenAggregateSchemaFieldCount(final int count) { - final List valueFields = IntStream - .range(0, count) - .mapToObj(i -> field) - .collect(Collectors.toList()); - - when(aggregateSchema.value()).thenReturn(valueFields); - } } diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java 
b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java index 6a8b3137310f..f1ae08a81ce9 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java @@ -15,83 +15,42 @@ package io.confluent.ksql.structured; -import static java.util.Collections.emptyList; -import static java.util.Collections.emptyMap; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import io.confluent.kafka.schemaregistry.client.MockSchemaRegistryClient; -import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; -import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.expression.tree.QualifiedName; import io.confluent.ksql.execution.expression.tree.QualifiedNameReference; -import io.confluent.ksql.execution.plan.DefaultExecutionStepProperties; import io.confluent.ksql.execution.plan.ExecutionStep; import io.confluent.ksql.execution.plan.Formats; import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.execution.streams.MaterializedFactory; -import io.confluent.ksql.execution.streams.StreamsUtil; import io.confluent.ksql.function.InternalFunctionRegistry; -import io.confluent.ksql.function.KsqlAggregateFunction; 
-import io.confluent.ksql.function.TableAggregationFunction; -import io.confluent.ksql.function.udaf.KudafInitializer; -import io.confluent.ksql.logging.processing.NoopProcessingLogContext; -import io.confluent.ksql.logging.processing.ProcessingLogContext; -import io.confluent.ksql.metastore.MetaStore; import io.confluent.ksql.metastore.model.KeyField; -import io.confluent.ksql.metastore.model.KsqlTable; import io.confluent.ksql.parser.tree.WindowExpression; -import io.confluent.ksql.planner.plan.PlanNode; import io.confluent.ksql.query.QueryId; -import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.LogicalSchema; -import io.confluent.ksql.schema.ksql.PersistenceSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.serde.Format; import io.confluent.ksql.serde.FormatInfo; -import io.confluent.ksql.serde.GenericRowSerDe; import io.confluent.ksql.serde.KeyFormat; import io.confluent.ksql.serde.KeySerde; import io.confluent.ksql.serde.SerdeOption; import io.confluent.ksql.serde.ValueFormat; -import io.confluent.ksql.testutils.AnalysisTestUtil; import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.KsqlException; -import io.confluent.ksql.util.MetaStoreFixture; -import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -import org.apache.kafka.common.serialization.Serde; -import org.apache.kafka.common.serialization.Serdes; -import org.apache.kafka.connect.data.Schema; import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.StreamsBuilder; -import org.apache.kafka.streams.kstream.Consumed; -import org.apache.kafka.streams.kstream.Initializer; import org.apache.kafka.streams.kstream.KGroupedTable; import org.apache.kafka.streams.kstream.KTable; -import 
org.apache.kafka.streams.kstream.Materialized; import org.apache.kafka.streams.kstream.ValueMapper; -import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -103,16 +62,29 @@ @SuppressWarnings("unchecked") @RunWith(MockitoJUnitRunner.class) public class SchemaKGroupedTableTest { + private static final LogicalSchema IN_SCHEMA = LogicalSchema.builder() + .valueColumn("IN0", SqlTypes.STRING) + .valueColumn("IN1", SqlTypes.INTEGER) + .build(); + private static final LogicalSchema AGG_SCHEMA = LogicalSchema.builder() + .valueColumn("IN0", SqlTypes.STRING) + .valueColumn("AGG0", SqlTypes.BIGINT) + .valueColumn("AGG1", SqlTypes.BIGINT) + .build(); + private static final LogicalSchema OUT_SCHEMA = LogicalSchema.builder() + .valueColumn("IN0", SqlTypes.STRING) + .valueColumn("OUT0", SqlTypes.STRING) + .valueColumn("OUT1", SqlTypes.STRING) + .build(); + private static final FunctionCall MIN = udaf("MIN"); + private static final FunctionCall MAX = udaf("MAX"); + private static final FunctionCall SUM = udaf("SUM"); + private static final FunctionCall COUNT = udaf("COUNT"); + private final KsqlConfig ksqlConfig = new KsqlConfig(Collections.emptyMap()); private final InternalFunctionRegistry functionRegistry = new InternalFunctionRegistry(); - private final ProcessingLogContext processingLogContext = ProcessingLogContext.create(); private final KGroupedTable mockKGroupedTable = mock(KGroupedTable.class); - private final LogicalSchema schema = LogicalSchema.builder() - .valueColumn("GROUPING_COLUMN", SqlTypes.STRING) - .valueColumn("AGG_VALUE", SqlTypes.INTEGER) - .build(); private final MaterializedFactory materializedFactory = mock(MaterializedFactory.class); - private final MetaStore metaStore = MetaStoreFixture.getNewMetaStore(new InternalFunctionRegistry()); private final QueryContext.Stacker queryContext = new QueryContext.Stacker(new QueryId("query")).push("node"); private final ValueFormat valueFormat = 
ValueFormat.of(FormatInfo.of(Format.JSON)); @@ -124,99 +96,23 @@ public class SchemaKGroupedTableTest { @Mock private KeySerde keySerde; @Mock - private LogicalSchema aggregateSchema; - @Mock - private Initializer initializer; - @Mock - private Serde topicValueSerDe; - @Mock - private FunctionCall aggCall1; - @Mock - private FunctionCall aggCall2; - @Mock - private Column field; - @Mock - private KsqlAggregateFunction otherFunc; - @Mock - private TableAggregationFunction tableFunc; - @Mock private KsqlQueryBuilder queryBuilder; @Mock private KTable table; - private KTable kTable; - private KsqlTable ksqlTable; - @Before public void init() { - ksqlTable = (KsqlTable) metaStore.getSource("TEST2"); - final StreamsBuilder builder = new StreamsBuilder(); - - final Serde rowSerde = GenericRowSerDe.from( - ksqlTable.getKsqlTopic().getValueFormat().getFormatInfo(), - PersistenceSchema.from(ksqlTable.getSchema().valueConnectSchema(), false), - new KsqlConfig(Collections.emptyMap()), - MockSchemaRegistryClient::new, - "", - NoopProcessingLogContext.INSTANCE - ); - - kTable = builder.table( - ksqlTable.getKsqlTopic().getKafkaTopicName(), - Consumed.with(Serdes.String(), rowSerde) - ); - when(queryBuilder.getFunctionRegistry()).thenReturn(functionRegistry); - when(queryBuilder.getKsqlConfig()).thenReturn(ksqlConfig); - - when(aggregateSchema.findValueColumn("GROUPING_COLUMN")) - .thenReturn(Optional.of(Column.of("GROUPING_COLUMN", SqlTypes.STRING))); - - when(aggregateSchema.value()).thenReturn(mock(List.class)); - when(mockKGroupedTable.aggregate(any(), any(), any(), any())).thenReturn(table); when(table.mapValues(any(ValueMapper.class))).thenReturn(table); } private ExecutionStep buildSourceTableStep(final LogicalSchema schema) { final ExecutionStep step = mock(ExecutionStep.class); - when(step.getProperties()).thenReturn( - new DefaultExecutionStepProperties(schema, queryContext.getQueryContext()) - ); when(step.getSchema()).thenReturn(schema); return step; } - private 
SchemaKGroupedTable buildSchemaKGroupedTableFromQuery( - final String query, - final String...groupByColumns - ) { - when(keySerde.rebind(any(PersistenceSchema.class))).thenReturn(keySerde); - - final PlanNode logicalPlan = AnalysisTestUtil.buildLogicalPlan(ksqlConfig, query, metaStore); - - final SchemaKTable initialSchemaKTable = new SchemaKTable( - kTable, - buildSourceTableStep(logicalPlan.getTheSourceNode().getSchema()), - keyFormat, - keySerde, - logicalPlan.getTheSourceNode().getKeyField(), - new ArrayList<>(), - SchemaKStream.Type.SOURCE, - ksqlConfig, - functionRegistry); - - final List groupByExpressions = - Arrays.stream(groupByColumns) - .map(c -> new QualifiedNameReference(QualifiedName.of("TEST1", c))) - .collect(Collectors.toList()); - - final SchemaKGroupedStream groupedSchemaKTable = initialSchemaKTable.groupBy( - valueFormat, groupByExpressions, queryContext, queryBuilder); - Assert.assertThat(groupedSchemaKTable, instanceOf(SchemaKGroupedTable.class)); - return (SchemaKGroupedTable)groupedSchemaKTable; - } - @Test public void shouldFailWindowedTableAggregation() { // Given: @@ -231,56 +127,39 @@ public void shouldFailWindowedTableAggregation() { // When: groupedTable.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - emptyMap(), + OUT_SCHEMA, + AGG_SCHEMA, + 1, + ImmutableList.of(SUM, COUNT), windowExp, valueFormat, - topicValueSerDe, - queryContext + queryContext, + queryBuilder ); } @Test public void shouldFailUnsupportedAggregateFunction() { - final SchemaKGroupedTable kGroupedTable = buildSchemaKGroupedTableFromQuery( - "SELECT col0, col1, col2 FROM test1 EMIT CHANGES;", "COL1", "COL2"); - final InternalFunctionRegistry functionRegistry = new InternalFunctionRegistry(); - try { - final Map aggValToFunctionMap = new HashMap<>(); - aggValToFunctionMap.put( - 0, functionRegistry.getAggregate("MAX", Schema.OPTIONAL_INT64_SCHEMA)); - aggValToFunctionMap.put( - 1, functionRegistry.getAggregate("MIN", Schema.OPTIONAL_INT64_SCHEMA)); 
+ // Given: + final SchemaKGroupedTable kGroupedTable = + buildSchemaKGroupedTable(mockKGroupedTable, materializedFactory); - givenAggregateSchemaFieldCount(aggValToFunctionMap.size() + 1); + // Then: + expectedException.expect(KsqlException.class); + expectedException.expectMessage( + "The aggregation function(s) (MIN, MAX) cannot be applied to a table."); - kGroupedTable.aggregate( - aggregateSchema, - new KudafInitializer(1), - 1, - ImmutableList.of(aggCall1, aggCall2), - aggValToFunctionMap, - null, - valueFormat, - GenericRowSerDe.from( - FormatInfo.of(Format.JSON, Optional.empty()), - PersistenceSchema.from(ksqlTable.getSchema().valueConnectSchema(), false), - ksqlConfig, - () -> null, - "test", - processingLogContext), - queryContext - ); - Assert.fail("Should fail to build topology for aggregation with unsupported function"); - } catch(final KsqlException e) { - Assert.assertThat( - e.getMessage(), - equalTo( - "The aggregation function(s) (MAX, MIN) cannot be applied to a table.")); - } + // When: + kGroupedTable.aggregate( + OUT_SCHEMA, + AGG_SCHEMA, + 1, + ImmutableList.of(MIN, MAX), + null, + valueFormat, + queryContext, + queryBuilder + ); } private SchemaKGroupedTable buildSchemaKGroupedTable( @@ -289,80 +168,31 @@ private SchemaKGroupedTable buildSchemaKGroupedTable( ) { return new SchemaKGroupedTable( kGroupedTable, - buildSourceTableStep(schema), + buildSourceTableStep(IN_SCHEMA), keyFormat, keySerde, - KeyField.of(schema.value().get(0).name(), schema.value().get(0)), + KeyField.of(IN_SCHEMA.value().get(0).name(), IN_SCHEMA.value().get(0)), Collections.emptyList(), ksqlConfig, functionRegistry, materializedFactory); } - @Test - public void shouldUseMaterializedFactoryForStateStore() { - // Given: - final Serde valueSerde = mock(Serde.class); - final Materialized materialized = MaterializedFactory.create(ksqlConfig).create( - Serdes.String(), - valueSerde, - StreamsUtil.buildOpName(queryContext.getQueryContext())); - - 
when(materializedFactory.create(any(), any(), any())).thenReturn(materialized); - - final KTable mockKTable = mock(KTable.class); - when(mockKGroupedTable.aggregate(any(), any(), any(), any())).thenReturn(mockKTable); - - final SchemaKGroupedTable groupedTable = - buildSchemaKGroupedTable(mockKGroupedTable, materializedFactory); - - // When: - groupedTable.aggregate( - aggregateSchema, - () -> null, - 0, - emptyList(), - Collections.emptyMap(), - null, - valueFormat, - valueSerde, - queryContext); - - // Then: - verify(materializedFactory).create( - eq(keySerde), - same(valueSerde), - eq(StreamsUtil.buildOpName(queryContext.getQueryContext())) - ); - - verify(mockKGroupedTable).aggregate( - any(), - any(), - any(), - same(materialized) - ); - } - @Test public void shouldBuildStepForAggregate() { // Given: - final Map functions = ImmutableMap.of(1, tableFunc); - final SchemaKGroupedTable groupedTable = + final SchemaKGroupedTable kGroupedTable = buildSchemaKGroupedTable(mockKGroupedTable, materializedFactory); - when(aggregateSchema.value()).thenReturn( - ImmutableList.of(mock(Column.class), mock(Column.class))); - // When: - final SchemaKTable result = groupedTable.aggregate( - aggregateSchema, - initializer, + final SchemaKTable result = kGroupedTable.aggregate( + OUT_SCHEMA, + AGG_SCHEMA, 1, - ImmutableList.of(aggCall1), - functions, + ImmutableList.of(SUM, COUNT), null, valueFormat, - topicValueSerDe, - queryContext + queryContext, + queryBuilder ); // Then: @@ -371,37 +201,37 @@ public void shouldBuildStepForAggregate() { equalTo( ExecutionStepFactory.tableAggregate( queryContext, - groupedTable.getSourceTableStep(), - aggregateSchema, + kGroupedTable.getSourceTableStep(), + OUT_SCHEMA, Formats.of(keyFormat, valueFormat, SerdeOption.none()), 1, - ImmutableList.of(aggCall1) + ImmutableList.of(SUM, COUNT), + AGG_SCHEMA ) ) ); } @Test - public void shouldReturnKTableWithAggregateSchema() { + public void shouldReturnKTableWithOutputSchema() { // Given: final 
SchemaKGroupedTable groupedTable = buildSchemaKGroupedTable(mockKGroupedTable, materializedFactory); // When: final SchemaKTable result = groupedTable.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - emptyMap(), + OUT_SCHEMA, + AGG_SCHEMA, + 1, + ImmutableList.of(SUM, COUNT), null, valueFormat, - topicValueSerDe, - queryContext + queryContext, + queryBuilder ); // Then: - assertThat(result.getSchema(), is(aggregateSchema)); + assertThat(result.getSchema(), is(OUT_SCHEMA)); } @Test(expected = IllegalArgumentException.class) @@ -410,32 +240,23 @@ public void shouldThrowOnColumnCountMismatch() { final SchemaKGroupedTable groupedTable = buildSchemaKGroupedTable(mockKGroupedTable, materializedFactory); - // Agg schema has 2 fields: - givenAggregateSchemaFieldCount(2); - - // Where as params have 1 nonAgg and 2 agg fields: - final Map aggColumns = ImmutableMap.of(2, otherFunc); - // When: groupedTable.aggregate( - aggregateSchema, - initializer, + OUT_SCHEMA, + AGG_SCHEMA, 2, - ImmutableList.of(aggCall1), - aggColumns, + ImmutableList.of(SUM, COUNT), null, valueFormat, - topicValueSerDe, - queryContext + queryContext, + queryBuilder ); } - private void givenAggregateSchemaFieldCount(final int count) { - final List valueFields = IntStream - .range(0, count) - .mapToObj(i -> field) - .collect(Collectors.toList()); - - when(aggregateSchema.value()).thenReturn(valueFields); + private static FunctionCall udaf(final String name) { + return new FunctionCall( + QualifiedName.of(name), + ImmutableList.of(new QualifiedNameReference(QualifiedName.of("IN1"))) + ); } } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/TableAggregationFunction.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/TableAggregationFunction.java similarity index 87% rename from ksql-engine/src/main/java/io/confluent/ksql/function/TableAggregationFunction.java rename to 
ksql-execution/src/main/java/io/confluent/ksql/execution/function/TableAggregationFunction.java index fbef04cfe2cc..9249aa63522c 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/TableAggregationFunction.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/TableAggregationFunction.java @@ -13,7 +13,9 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.function; +package io.confluent.ksql.execution.function; + +import io.confluent.ksql.function.KsqlAggregateFunction; public interface TableAggregationFunction extends KsqlAggregateFunction { A undo(I valueToUndo, A aggregateValue); diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/function/UdafUtil.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/UdafUtil.java new file mode 100644 index 000000000000..f4463e8e0c2f --- /dev/null +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/UdafUtil.java @@ -0,0 +1,61 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ + +package io.confluent.ksql.execution.function; + +import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.util.ExpressionTypeManager; +import io.confluent.ksql.function.AggregateFunctionArguments; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.function.KsqlAggregateFunction; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.util.KsqlException; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.kafka.connect.data.Schema; + +public final class UdafUtil { + private UdafUtil() { + } + + @SuppressWarnings("deprecation") // Need to migrate away from Connect Schema use. + public static KsqlAggregateFunction resolveAggregateFunction( + final FunctionRegistry functionRegistry, + final FunctionCall functionCall, + final LogicalSchema schema + ) { + try { + final ExpressionTypeManager expressionTypeManager = + new ExpressionTypeManager(schema, functionRegistry); + final List functionArgs = functionCall.getArguments(); + final Schema expressionType = expressionTypeManager.getExpressionSchema(functionArgs.get(0)); + final KsqlAggregateFunction aggregateFunctionInfo = functionRegistry.getAggregate( + functionCall.getName().toString(), + expressionType + ); + + final List args = functionArgs.stream() + .map(Expression::toString) + .collect(Collectors.toList()); + + final int udafIndex = schema.valueColumnIndex(args.get(0)).getAsInt(); + + return aggregateFunctionInfo.getInstance(new AggregateFunctionArguments(udafIndex, args)); + } catch (final Exception e) { + throw new KsqlException("Failed to create aggregate function: " + functionCall, e); + } + } +} diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafAggregator.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafAggregator.java similarity index 94% rename from 
ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafAggregator.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafAggregator.java index af625865d086..6e28b1671846 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafAggregator.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafAggregator.java @@ -13,7 +13,7 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.function.udaf; +package io.confluent.ksql.execution.function.udaf; import com.google.common.collect.ImmutableMap; import io.confluent.ksql.GenericRow; @@ -119,4 +119,12 @@ public Merger getMerger() { return new GenericRow(columns); }; } + + public int getNonFuncColumnCount() { + return nonFuncColumnCount; + } + + public Map getAggValToAggFunctionMap() { + return aggValToAggFunctionMap; + } } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafInitializer.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafInitializer.java similarity index 77% rename from ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafInitializer.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafInitializer.java index e59f41da3e4f..89396331b580 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafInitializer.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafInitializer.java @@ -13,11 +13,12 @@ * specific language governing permissions and limitations under the License. 
*/ -package io.confluent.ksql.function.udaf; +package io.confluent.ksql.execution.function.udaf; +import com.google.common.collect.ImmutableList; import io.confluent.ksql.GenericRow; -import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -25,11 +26,14 @@ public class KudafInitializer implements Initializer { - private final List aggValueSuppliers = new ArrayList<>(); + private final List aggValueSuppliers; private final int nonAggValSize; - public KudafInitializer(final int nonAggValSize) { + public KudafInitializer(final int nonAggValSize, final List aggValueSuppliers) { this.nonAggValSize = nonAggValSize; + this.aggValueSuppliers = ImmutableList.copyOf( + Objects.requireNonNull(aggValueSuppliers, "aggValueSuppliers") + ); } @Override diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafUndoAggregator.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafUndoAggregator.java similarity index 86% rename from ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafUndoAggregator.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafUndoAggregator.java index e5b6989e93de..82cabe5098e1 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafUndoAggregator.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafUndoAggregator.java @@ -13,11 +13,11 @@ * specific language governing permissions and limitations under the License. 
*/ -package io.confluent.ksql.function.udaf; +package io.confluent.ksql.execution.function.udaf; import com.google.common.collect.ImmutableMap; import io.confluent.ksql.GenericRow; -import io.confluent.ksql.function.TableAggregationFunction; +import io.confluent.ksql.execution.function.TableAggregationFunction; import java.util.Map; import java.util.Objects; import org.apache.kafka.connect.data.Struct; @@ -53,4 +53,12 @@ public GenericRow apply(final Struct k, final GenericRow rowValue, final Generic aggRowValue.getColumns().get(aggRowIndex)))); return aggRowValue; } + + public int getNonFuncColumnCount() { + return nonFuncColumnCount; + } + + public Map getAggValToAggFunctionMap() { + return aggValToAggFunctionMap; + } } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowSelectMapper.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/window/WindowSelectMapper.java similarity index 86% rename from ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowSelectMapper.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/window/WindowSelectMapper.java index 76954479148d..09860e299ab4 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowSelectMapper.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/window/WindowSelectMapper.java @@ -13,7 +13,7 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.function.udaf.window; +package io.confluent.ksql.execution.function.udaf.window; import com.google.common.collect.ImmutableMap; import io.confluent.ksql.GenericRow; @@ -26,14 +26,17 @@ import org.apache.kafka.streams.kstream.Windowed; /** - * Used to handle the special cased {@link WindowStartKudaf} and {@link WindowEndKudaf}. + * Used to handle the special cased {WindowStart} and {WindowEnd}. 
*/ public final class WindowSelectMapper implements ValueMapperWithKey, GenericRow, GenericRow> { + public static final String WINDOW_START_NAME = "WindowStart"; + public static final String WINDOW_END_NAME = "WindowEnd"; + private static final Map WINDOW_FUNCTION_NAMES = ImmutableMap.of( - WindowStartKudaf.getFunctionName().toUpperCase(), Type.StartTime, - WindowEndKudaf.getFunctionName().toUpperCase(), Type.EndTime + WINDOW_START_NAME.toUpperCase(), Type.StartTime, + WINDOW_END_NAME.toUpperCase(), Type.EndTime ); private final Map windowSelects; diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamAggregate.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamAggregate.java index 1a47b2ffe9ab..08b12b674204 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamAggregate.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamAggregate.java @@ -15,31 +15,39 @@ package io.confluent.ksql.execution.plan; import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.schema.ksql.LogicalSchema; import java.util.Collections; import java.util.List; import java.util.Objects; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KTable; @Immutable -public class StreamAggregate implements ExecutionStep { +public class StreamAggregate implements ExecutionStep> { private final ExecutionStepProperties properties; - private final ExecutionStep source; + private final ExecutionStep> source; private final Formats formats; private final int nonFuncColumnCount; private final List aggregations; + private final LogicalSchema aggregationSchema; public StreamAggregate( final ExecutionStepProperties properties, - final 
ExecutionStep source, + final ExecutionStep> source, final Formats formats, final int nonFuncColumnCount, - final List aggregations) { + final List aggregations, + final LogicalSchema aggregationSchema) { this.properties = Objects.requireNonNull(properties, "properties"); this.source = Objects.requireNonNull(source, "source"); this.formats = Objects.requireNonNull(formats, "formats"); this.nonFuncColumnCount = nonFuncColumnCount; - this.aggregations = Objects.requireNonNull(aggregations); + this.aggregations = Objects.requireNonNull(aggregations, "aggregations"); + this.aggregationSchema = Objects.requireNonNull(aggregationSchema, "aggregationSchema"); } @Override @@ -52,8 +60,24 @@ public List> getSources() { return Collections.singletonList(source); } + public int getNonFuncColumnCount() { + return nonFuncColumnCount; + } + + public List getAggregations() { + return aggregations; + } + + public Formats getFormats() { + return formats; + } + + public LogicalSchema getAggregationSchema() { + return aggregationSchema; + } + @Override - public T build(final KsqlQueryBuilder streamsBuilder) { + public KTable build(final KsqlQueryBuilder streamsBuilder) { throw new UnsupportedOperationException(); } @@ -65,7 +89,7 @@ public boolean equals(final Object o) { if (o == null || getClass() != o.getClass()) { return false; } - final StreamAggregate that = (StreamAggregate) o; + final StreamAggregate that = (StreamAggregate) o; return Objects.equals(properties, that.properties) && Objects.equals(source, that.source) && Objects.equals(formats, that.formats) diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamWindowedAggregate.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamWindowedAggregate.java new file mode 100644 index 000000000000..015f126bf054 --- /dev/null +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamWindowedAggregate.java @@ -0,0 +1,114 @@ +/* + * Copyright 2019 Confluent Inc. 
+ * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.plan; + +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Windowed; + +public class StreamWindowedAggregate + implements ExecutionStep, GenericRow>> { + private final ExecutionStepProperties properties; + private final ExecutionStep> source; + private final Formats formats; + private final int nonFuncColumnCount; + private final List aggregations; + private final LogicalSchema aggregationSchema; + private final KsqlWindowExpression windowExpression; + + public StreamWindowedAggregate( + final ExecutionStepProperties properties, + final ExecutionStep> source, + final Formats formats, + final int nonFuncColumnCount, + final List aggregations, + final LogicalSchema aggregationSchema, + final KsqlWindowExpression windowExpression) { + this.properties = Objects.requireNonNull(properties, "properties"); + this.source = Objects.requireNonNull(source, "source"); + this.formats = 
Objects.requireNonNull(formats, "formats"); + this.nonFuncColumnCount = nonFuncColumnCount; + this.aggregations = Objects.requireNonNull(aggregations, "aggregations"); + this.aggregationSchema = Objects.requireNonNull(aggregationSchema, "aggregationSchema"); + this.windowExpression = Objects.requireNonNull(windowExpression, "windowExpression"); + } + + @Override + public ExecutionStepProperties getProperties() { + return properties; + } + + @Override + public List> getSources() { + return Collections.singletonList(source); + } + + public int getNonFuncColumnCount() { + return nonFuncColumnCount; + } + + public List getAggregations() { + return aggregations; + } + + public Formats getFormats() { + return formats; + } + + public LogicalSchema getAggregationSchema() { + return aggregationSchema; + } + + public KsqlWindowExpression getWindowExpression() { + return windowExpression; + } + + @Override + public KTable, GenericRow> build(final KsqlQueryBuilder streamsBuilder) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final StreamWindowedAggregate that = (StreamWindowedAggregate) o; + return Objects.equals(properties, that.properties) + && Objects.equals(source, that.source) + && Objects.equals(formats, that.formats) + && Objects.equals(aggregations, that.aggregations) + && nonFuncColumnCount == that.nonFuncColumnCount; + } + + @Override + public int hashCode() { + + return Objects.hash(properties, source, formats, aggregations, nonFuncColumnCount); + } +} diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableAggregate.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableAggregate.java index 768d38d7948b..4a54c3918706 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableAggregate.java +++ 
b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableAggregate.java @@ -15,31 +15,39 @@ package io.confluent.ksql.execution.plan; import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.schema.ksql.LogicalSchema; import java.util.Collections; import java.util.List; import java.util.Objects; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedTable; +import org.apache.kafka.streams.kstream.KTable; @Immutable -public class TableAggregate implements ExecutionStep { +public class TableAggregate implements ExecutionStep> { private final ExecutionStepProperties properties; - private final ExecutionStep source; + private final ExecutionStep> source; private final Formats formats; private final int nonFuncColumnCount; private final List aggregations; + private final LogicalSchema aggregateSchema; public TableAggregate( final ExecutionStepProperties properties, - final ExecutionStep source, + final ExecutionStep> source, final Formats formats, final int nonFuncColumnCount, - final List aggregations) { + final List aggregations, + final LogicalSchema aggregateSchema) { this.properties = Objects.requireNonNull(properties, "properties"); this.source = Objects.requireNonNull(source, "source"); this.formats = Objects.requireNonNull(formats, "formats"); this.nonFuncColumnCount = nonFuncColumnCount; this.aggregations = Objects.requireNonNull(aggregations, "aggValToFunctionMap"); + this.aggregateSchema = Objects.requireNonNull(aggregateSchema, "aggregateSchema"); } @Override @@ -52,8 +60,24 @@ public List> getSources() { return Collections.singletonList(source); } + public Formats getFormats() { + return formats; + } + + public List getAggregations() { + return aggregations; + } + + public int getNonFuncColumnCount() { + return 
nonFuncColumnCount; + } + + public LogicalSchema getAggregateSchema() { + return aggregateSchema; + } + @Override - public T build(final KsqlQueryBuilder builder) { + public KTable build(final KsqlQueryBuilder builder) { throw new UnsupportedOperationException(); } @@ -65,7 +89,7 @@ public boolean equals(final Object o) { if (o == null || getClass() != o.getClass()) { return false; } - final TableAggregate that = (TableAggregate) o; + final TableAggregate that = (TableAggregate) o; return Objects.equals(properties, that.properties) && Objects.equals(source, that.source) && Objects.equals(formats, that.formats) diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/HoppingWindowExpression.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/HoppingWindowExpression.java similarity index 72% rename from ksql-parser/src/main/java/io/confluent/ksql/parser/tree/HoppingWindowExpression.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/windows/HoppingWindowExpression.java index 95f33723120d..01f4ce062884 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/HoppingWindowExpression.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/HoppingWindowExpression.java @@ -13,13 +13,11 @@ * specific language governing permissions and limitations under the License. 
*/ -package io.confluent.ksql.parser.tree; +package io.confluent.ksql.execution.windows; import static java.util.Objects.requireNonNull; import com.google.errorprone.annotations.Immutable; -import io.confluent.ksql.GenericRow; -import io.confluent.ksql.function.UdafAggregator; import io.confluent.ksql.model.WindowType; import io.confluent.ksql.parser.NodeLocation; import io.confluent.ksql.serde.WindowInfo; @@ -27,12 +25,6 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.TimeUnit; -import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; -import org.apache.kafka.streams.kstream.KGroupedStream; -import org.apache.kafka.streams.kstream.KTable; -import org.apache.kafka.streams.kstream.Materialized; -import org.apache.kafka.streams.kstream.TimeWindows; @Immutable public class HoppingWindowExpression extends KsqlWindowExpression { @@ -73,8 +65,24 @@ public WindowInfo getWindowInfo() { ); } + public TimeUnit getSizeUnit() { + return sizeUnit; + } + + public long getSize() { + return size; + } + + public TimeUnit getAdvanceByUnit() { + return advanceByUnit; + } + + public long getAdvanceBy() { + return advanceBy; + } + @Override - public R accept(final AstVisitor visitor, final C context) { + public R accept(final WindowVisitor visitor, final C context) { return visitor.visitHoppingWindowExpression(this, context); } @@ -102,21 +110,4 @@ public boolean equals(final Object o) { && hoppingWindowExpression.advanceBy == advanceBy && hoppingWindowExpression .advanceByUnit == advanceByUnit; } - - @SuppressWarnings("unchecked") - @Override - public KTable applyAggregate( - final KGroupedStream groupedStream, - final Initializer initializer, - final UdafAggregator aggregator, - final Materialized materialized - ) { - final TimeWindows windows = TimeWindows - .of(Duration.ofMillis(sizeUnit.toMillis(size))) - .advanceBy(Duration.ofMillis(advanceByUnit.toMillis(advanceBy))); - - return groupedStream - 
.windowedBy(windows) - .aggregate(initializer, aggregator, materialized); - } } diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/KsqlWindowExpression.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/KsqlWindowExpression.java new file mode 100644 index 000000000000..f15717187065 --- /dev/null +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/KsqlWindowExpression.java @@ -0,0 +1,34 @@ +/* + * Copyright 2018 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ + +package io.confluent.ksql.execution.windows; + +import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.parser.Node; +import io.confluent.ksql.parser.NodeLocation; +import io.confluent.ksql.serde.WindowInfo; +import java.util.Optional; + +@Immutable +public abstract class KsqlWindowExpression extends Node { + + KsqlWindowExpression(final Optional nodeLocation) { + super(nodeLocation); + } + + public abstract WindowInfo getWindowInfo(); + + public abstract R accept(WindowVisitor visitor, C context); +} diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/SessionWindowExpression.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/SessionWindowExpression.java similarity index 66% rename from ksql-parser/src/main/java/io/confluent/ksql/parser/tree/SessionWindowExpression.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/windows/SessionWindowExpression.java index f8c4154f5b01..ed22f3f4d06e 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/SessionWindowExpression.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/SessionWindowExpression.java @@ -13,26 +13,17 @@ * specific language governing permissions and limitations under the License. 
*/ -package io.confluent.ksql.parser.tree; +package io.confluent.ksql.execution.windows; import static java.util.Objects.requireNonNull; import com.google.errorprone.annotations.Immutable; -import io.confluent.ksql.GenericRow; -import io.confluent.ksql.function.UdafAggregator; import io.confluent.ksql.model.WindowType; import io.confluent.ksql.parser.NodeLocation; import io.confluent.ksql.serde.WindowInfo; -import java.time.Duration; import java.util.Objects; import java.util.Optional; import java.util.concurrent.TimeUnit; -import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; -import org.apache.kafka.streams.kstream.KGroupedStream; -import org.apache.kafka.streams.kstream.KTable; -import org.apache.kafka.streams.kstream.Materialized; -import org.apache.kafka.streams.kstream.SessionWindows; @Immutable public class SessionWindowExpression extends KsqlWindowExpression { @@ -54,13 +45,21 @@ public SessionWindowExpression( this.sizeUnit = requireNonNull(sizeUnit, "sizeUnit"); } + public TimeUnit getSizeUnit() { + return sizeUnit; + } + + public long getGap() { + return gap; + } + @Override public WindowInfo getWindowInfo() { return WindowInfo.of(WindowType.SESSION, Optional.empty()); } @Override - public R accept(final AstVisitor visitor, final C context) { + public R accept(final WindowVisitor visitor, final C context) { return visitor.visitSessionWindowExpression(this, context); } @@ -85,18 +84,4 @@ public boolean equals(final Object o) { final SessionWindowExpression sessionWindowExpression = (SessionWindowExpression) o; return sessionWindowExpression.gap == gap && sessionWindowExpression.sizeUnit == sizeUnit; } - - @SuppressWarnings("unchecked") - @Override - public KTable applyAggregate(final KGroupedStream groupedStream, - final Initializer initializer, - final UdafAggregator aggregator, - final Materialized materialized) { - - final SessionWindows windows = 
SessionWindows.with(Duration.ofMillis(sizeUnit.toMillis(gap))); - - return groupedStream - .windowedBy(windows) - .aggregate(initializer, aggregator, aggregator.getMerger(), materialized); - } } diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/TumblingWindowExpression.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/TumblingWindowExpression.java similarity index 70% rename from ksql-parser/src/main/java/io/confluent/ksql/parser/tree/TumblingWindowExpression.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/windows/TumblingWindowExpression.java index 6db9121cbfe5..45b5602dcad2 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/TumblingWindowExpression.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/TumblingWindowExpression.java @@ -13,13 +13,11 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.parser.tree; +package io.confluent.ksql.execution.windows; import static java.util.Objects.requireNonNull; import com.google.errorprone.annotations.Immutable; -import io.confluent.ksql.GenericRow; -import io.confluent.ksql.function.UdafAggregator; import io.confluent.ksql.model.WindowType; import io.confluent.ksql.parser.NodeLocation; import io.confluent.ksql.serde.WindowInfo; @@ -27,12 +25,6 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.TimeUnit; -import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; -import org.apache.kafka.streams.kstream.KGroupedStream; -import org.apache.kafka.streams.kstream.KTable; -import org.apache.kafka.streams.kstream.Materialized; -import org.apache.kafka.streams.kstream.TimeWindows; @Immutable public class TumblingWindowExpression extends KsqlWindowExpression { @@ -62,8 +54,16 @@ public WindowInfo getWindowInfo() { ); } + public TimeUnit getSizeUnit() { + return sizeUnit; + } + + public long 
getSize() { + return size; + } + @Override - public R accept(final AstVisitor visitor, final C context) { + public R accept(final WindowVisitor visitor, final C context) { return visitor.visitTumblingWindowExpression(this, context); } @@ -88,19 +88,4 @@ public boolean equals(final Object o) { final TumblingWindowExpression tumblingWindowExpression = (TumblingWindowExpression) o; return tumblingWindowExpression.size == size && tumblingWindowExpression.sizeUnit == sizeUnit; } - - @SuppressWarnings("unchecked") - @Override - public KTable applyAggregate(final KGroupedStream groupedStream, - final Initializer initializer, - final UdafAggregator aggregator, - final Materialized materialized) { - - final TimeWindows windows = TimeWindows.of(Duration.ofMillis(sizeUnit.toMillis(size))); - - return groupedStream - .windowedBy(windows) - .aggregate(initializer, aggregator, materialized); - - } } diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/WindowVisitor.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/WindowVisitor.java new file mode 100644 index 000000000000..f32ffb2d088f --- /dev/null +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/WindowVisitor.java @@ -0,0 +1,24 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ + +package io.confluent.ksql.execution.windows; + +public interface WindowVisitor { + R visitHoppingWindowExpression(HoppingWindowExpression hoppingWindowExpression, C ctx); + + R visitSessionWindowExpression(SessionWindowExpression sessionWindowExpression, C ctx); + + R visitTumblingWindowExpression(TumblingWindowExpression tumblingWindowExpression, C ctx); +} diff --git a/ksql-execution/src/test/java/io/confluent/ksql/execution/function/UdafUtilTest.java b/ksql-execution/src/test/java/io/confluent/ksql/execution/function/UdafUtilTest.java new file mode 100644 index 000000000000..62d804d67017 --- /dev/null +++ b/ksql-execution/src/test/java/io/confluent/ksql/execution/function/UdafUtilTest.java @@ -0,0 +1,109 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ + +package io.confluent.ksql.execution.function; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.expression.tree.QualifiedName; +import io.confluent.ksql.execution.expression.tree.QualifiedNameReference; +import io.confluent.ksql.function.AggregateFunctionArguments; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.function.KsqlAggregateFunction; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import org.apache.kafka.connect.data.Schema; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class UdafUtilTest { + private static final LogicalSchema SCHEMA = LogicalSchema.builder() + .valueColumn("FOO", SqlTypes.INTEGER) + .valueColumn("BAR", SqlTypes.BIGINT) + .build(); + private static final FunctionCall FUNCTION_CALL = new FunctionCall( + QualifiedName.of("AGG"), + ImmutableList.of(new QualifiedNameReference(QualifiedName.of("BAR"))) + ); + + @Mock + private FunctionRegistry functionRegistry; + @Mock + private KsqlAggregateFunction function; + @Mock + private KsqlAggregateFunction resolved; + @Captor + private ArgumentCaptor argumentsCaptor; + + @Before + @SuppressWarnings("unchecked") + public void init() { + when(functionRegistry.getAggregate(any(), any())).thenReturn(function); + when(function.getInstance(any())).thenReturn(resolved); + } + + @Test + 
public void shouldResolveUDAF() { + // When: + final KsqlAggregateFunction returned = + UdafUtil.resolveAggregateFunction(functionRegistry, FUNCTION_CALL, SCHEMA); + + // Then: + assertThat(returned, is(resolved)); + } + + @Test + public void shouldGetAggregateWithCorrectName() { + // When: + UdafUtil.resolveAggregateFunction(functionRegistry, FUNCTION_CALL, SCHEMA); + + // Then: + verify(functionRegistry).getAggregate(eq("AGG"), any()); + } + + @Test + public void shouldGetAggregateWithCorrectType() { + // When: + UdafUtil.resolveAggregateFunction(functionRegistry, FUNCTION_CALL, SCHEMA); + + // Then: + verify(functionRegistry).getAggregate(any(), eq(Schema.OPTIONAL_INT64_SCHEMA)); + } + + @Test + public void shouldResolveWithCorrectArgs() { + // When: + UdafUtil.resolveAggregateFunction(functionRegistry, FUNCTION_CALL, SCHEMA); + + // Then: + verify(function).getInstance(argumentsCaptor.capture()); + final AggregateFunctionArguments arguments = argumentsCaptor.getValue(); + assertThat(arguments.udafIndex(), equalTo(1)); + } +} \ No newline at end of file diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java b/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java index c8c82cc4ad82..43d81eff2466 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java +++ b/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java @@ -52,6 +52,9 @@ import io.confluent.ksql.execution.expression.tree.TimeLiteral; import io.confluent.ksql.execution.expression.tree.TimestampLiteral; import io.confluent.ksql.execution.expression.tree.WhenClause; +import io.confluent.ksql.execution.windows.HoppingWindowExpression; +import io.confluent.ksql.execution.windows.SessionWindowExpression; +import io.confluent.ksql.execution.windows.TumblingWindowExpression; import io.confluent.ksql.metastore.MetaStore; import io.confluent.ksql.metastore.model.DataSource; import io.confluent.ksql.parser.SqlBaseParser.CreateConnectorContext; 
@@ -85,7 +88,6 @@ import io.confluent.ksql.parser.tree.Explain; import io.confluent.ksql.parser.tree.GroupBy; import io.confluent.ksql.parser.tree.GroupingElement; -import io.confluent.ksql.parser.tree.HoppingWindowExpression; import io.confluent.ksql.parser.tree.InsertInto; import io.confluent.ksql.parser.tree.InsertValues; import io.confluent.ksql.parser.tree.Join; @@ -108,7 +110,6 @@ import io.confluent.ksql.parser.tree.RunScript; import io.confluent.ksql.parser.tree.Select; import io.confluent.ksql.parser.tree.SelectItem; -import io.confluent.ksql.parser.tree.SessionWindowExpression; import io.confluent.ksql.parser.tree.SetProperty; import io.confluent.ksql.parser.tree.ShowColumns; import io.confluent.ksql.parser.tree.SimpleGroupBy; @@ -120,7 +121,6 @@ import io.confluent.ksql.parser.tree.TableElement.Namespace; import io.confluent.ksql.parser.tree.TableElements; import io.confluent.ksql.parser.tree.TerminateQuery; -import io.confluent.ksql.parser.tree.TumblingWindowExpression; import io.confluent.ksql.parser.tree.UnsetProperty; import io.confluent.ksql.parser.tree.WindowExpression; import io.confluent.ksql.parser.tree.WithinExpression; diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/rewrite/StatementRewriter.java b/ksql-parser/src/main/java/io/confluent/ksql/parser/rewrite/StatementRewriter.java index e9101214a346..2e0ee33f1795 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/rewrite/StatementRewriter.java +++ b/ksql-parser/src/main/java/io/confluent/ksql/parser/rewrite/StatementRewriter.java @@ -30,7 +30,6 @@ import io.confluent.ksql.parser.tree.GroupingElement; import io.confluent.ksql.parser.tree.InsertInto; import io.confluent.ksql.parser.tree.Join; -import io.confluent.ksql.parser.tree.KsqlWindowExpression; import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.parser.tree.RegisterType; import io.confluent.ksql.parser.tree.Relation; @@ -214,14 +213,8 @@ protected AstNode visitWindowExpression(final 
WindowExpression node, final C con return new WindowExpression( node.getLocation(), node.getWindowName(), - (KsqlWindowExpression) rewriter.apply(node.getKsqlWindowExpression(), context)); - } - - @Override - protected AstNode visitKsqlWindowExpression( - final KsqlWindowExpression node, - final C context) { - return node; + node.getKsqlWindowExpression() + ); } @Override diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/AstVisitor.java b/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/AstVisitor.java index f232dda4adc3..64ef262434b5 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/AstVisitor.java +++ b/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/AstVisitor.java @@ -92,22 +92,6 @@ protected R visitWindowExpression(final WindowExpression node, final C context) return visitNode(node, context); } - protected R visitKsqlWindowExpression(final KsqlWindowExpression node, final C context) { - return visitNode(node, context); - } - - protected R visitTumblingWindowExpression(final TumblingWindowExpression node, final C context) { - return visitKsqlWindowExpression(node, context); - } - - protected R visitHoppingWindowExpression(final HoppingWindowExpression node, final C context) { - return visitKsqlWindowExpression(node, context); - } - - protected R visitSessionWindowExpression(final SessionWindowExpression node, final C context) { - return visitKsqlWindowExpression(node, context); - } - protected R visitTableElement(final TableElement node, final C context) { return visitNode(node, context); } diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/KsqlWindowExpression.java b/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/KsqlWindowExpression.java deleted file mode 100644 index c7a48b6a7211..000000000000 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/KsqlWindowExpression.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright 2018 Confluent Inc. 
- * - * Licensed under the Confluent Community License (the "License"); you may not use - * this file except in compliance with the License. You may obtain a copy of the - * License at - * - * http://www.confluent.io/confluent-community-license - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package io.confluent.ksql.parser.tree; - -import com.google.errorprone.annotations.Immutable; -import io.confluent.ksql.GenericRow; -import io.confluent.ksql.function.UdafAggregator; -import io.confluent.ksql.parser.NodeLocation; -import io.confluent.ksql.serde.WindowInfo; -import java.util.Optional; -import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; -import org.apache.kafka.streams.kstream.KGroupedStream; -import org.apache.kafka.streams.kstream.KTable; -import org.apache.kafka.streams.kstream.Materialized; - -@Immutable -public abstract class KsqlWindowExpression extends AstNode { - - KsqlWindowExpression(final Optional location) { - super(location); - } - - public abstract KTable applyAggregate(KGroupedStream groupedStream, - Initializer initializer, - UdafAggregator aggregator, - Materialized materialized); - - public abstract WindowInfo getWindowInfo(); - - @Override - public R accept(final AstVisitor visitor, final C context) { - return visitor.visitKsqlWindowExpression(this, context); - } -} diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/WindowExpression.java b/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/WindowExpression.java index 8d889157ca3d..32bdaae39da2 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/WindowExpression.java +++ 
b/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/WindowExpression.java @@ -18,6 +18,7 @@ import static java.util.Objects.requireNonNull; import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.parser.NodeLocation; import java.util.Objects; import java.util.Optional; diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/rewrite/StatementRewriterTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/rewrite/StatementRewriterTest.java index 17c3f062b71b..edbda0ab284b 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/rewrite/StatementRewriterTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/rewrite/StatementRewriterTest.java @@ -23,7 +23,7 @@ import io.confluent.ksql.parser.tree.Join; import io.confluent.ksql.parser.tree.Join.Type; import io.confluent.ksql.parser.tree.JoinCriteria; -import io.confluent.ksql.parser.tree.KsqlWindowExpression; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.parser.tree.Relation; import io.confluent.ksql.parser.tree.ResultMaterialization; @@ -357,11 +357,8 @@ public void shouldRewriteJoinWithWindowExpression() { public void shouldRewriteWindowExpression() { // Given: final KsqlWindowExpression ksqlWindowExpression = mock(KsqlWindowExpression.class); - final KsqlWindowExpression rewrittenKsqlWindowExpression = mock(KsqlWindowExpression.class); final WindowExpression windowExpression = new WindowExpression(location, "name", ksqlWindowExpression); - when(mockRewriter.apply(ksqlWindowExpression, context)) - .thenReturn(rewrittenKsqlWindowExpression); // When: final AstNode rewritten = rewriter.rewrite(windowExpression, context); @@ -369,7 +366,7 @@ public void shouldRewriteWindowExpression() { // Then: assertThat( rewritten, - equalTo(new WindowExpression(location, "name", rewrittenKsqlWindowExpression)) + 
equalTo(new WindowExpression(location, "name", ksqlWindowExpression)) ); } diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/HoppingWindowExpressionTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/HoppingWindowExpressionTest.java index edf41face930..90865a4ef052 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/HoppingWindowExpressionTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/HoppingWindowExpressionTest.java @@ -1,4 +1,4 @@ -/* + /* * Copyright 2018 Confluent Inc. * * Licensed under the Confluent Community License (the "License"); you may not use @@ -27,6 +27,7 @@ import com.google.common.testing.EqualsTester; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.windows.HoppingWindowExpression; import io.confluent.ksql.function.UdafAggregator; import io.confluent.ksql.model.WindowType; import io.confluent.ksql.parser.NodeLocation; @@ -54,27 +55,6 @@ public class HoppingWindowExpressionTest { public static final NodeLocation SOME_LOCATION = new NodeLocation(0, 0); public static final NodeLocation OTHER_LOCATION = new NodeLocation(1, 0); - @Mock - private KGroupedStream stream; - @Mock - private TimeWindowedKStream windowedKStream; - @Mock - private UdafAggregator aggregator; - @Mock - private Initializer initializer; - @Mock - private Materialized> store; - private HoppingWindowExpression windowExpression; - - @Before - public void setUp() { - windowExpression = new HoppingWindowExpression(10, SECONDS, 4, TimeUnit.MILLISECONDS); - - when(stream - .windowedBy(any(TimeWindows.class))) - .thenReturn(windowedKStream); - } - @Test public void shouldImplementHashCodeAndEqualsProperty() { new EqualsTester() @@ -100,18 +80,6 @@ public void shouldImplementHashCodeAndEqualsProperty() { .testEquals(); } - @Test - public void shouldCreateHoppingWindowAggregate() { - // When: - windowExpression.applyAggregate(stream, initializer, aggregator, store); - - // Then: - 
verify(stream) - .windowedBy(TimeWindows.of(Duration.ofSeconds(10)).advanceBy(Duration.ofMillis(4L))); - - verify(windowedKStream).aggregate(initializer, aggregator, store); - } - @Test public void shouldReturnWindowInfo() { assertThat(new HoppingWindowExpression(10, SECONDS, 20, MINUTES).getWindowInfo(), diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/ParserModelTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/ParserModelTest.java index 038d2ce7cdba..fa43b9212a1b 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/ParserModelTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/ParserModelTest.java @@ -31,6 +31,8 @@ import io.confluent.ksql.execution.expression.tree.QualifiedName; import io.confluent.ksql.execution.expression.tree.StringLiteral; import io.confluent.ksql.execution.expression.tree.Type; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; +import io.confluent.ksql.execution.windows.TumblingWindowExpression; import io.confluent.ksql.parser.properties.with.CreateSourceAsProperties; import io.confluent.ksql.parser.properties.with.CreateSourceProperties; import io.confluent.ksql.properties.with.CommonCreateConfigs; diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/SessionWindowExpressionTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/SessionWindowExpressionTest.java index 68fee27e3d6a..0740e2d0f0fb 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/SessionWindowExpressionTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/SessionWindowExpressionTest.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.when; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.windows.SessionWindowExpression; import io.confluent.ksql.function.UdafAggregator; import io.confluent.ksql.model.WindowType; import io.confluent.ksql.serde.WindowInfo; @@ -46,39 +47,11 @@ 
@RunWith(MockitoJUnitRunner.class) public class SessionWindowExpressionTest { - @Mock - private KGroupedStream stream; - @Mock - private SessionWindowedKStream windowedKStream; - @Mock - private UdafAggregator aggregator; - @Mock - private Initializer initializer; - @Mock - private Materialized> store; - @Mock - private Merger merger; private SessionWindowExpression windowExpression; @Before public void setUp() { windowExpression = new SessionWindowExpression(5, TimeUnit.SECONDS); - - when(stream - .windowedBy(any(SessionWindows.class))) - .thenReturn(windowedKStream); - - when(aggregator.getMerger()).thenReturn(merger); - } - - @Test - public void shouldCreateSessionWindowed() { - // When: - windowExpression.applyAggregate(stream, initializer, aggregator, store); - - // Then: - verify(stream).windowedBy(SessionWindows.with(Duration.ofSeconds(5))); - verify(windowedKStream).aggregate(initializer, aggregator, merger, store); } @Test diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/TumblingWindowExpressionTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/TumblingWindowExpressionTest.java index 8c587e13b0c0..158f343a57e4 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/TumblingWindowExpressionTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/TumblingWindowExpressionTest.java @@ -18,65 +18,17 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import io.confluent.ksql.GenericRow; -import io.confluent.ksql.function.UdafAggregator; +import io.confluent.ksql.execution.windows.TumblingWindowExpression; import io.confluent.ksql.model.WindowType; import io.confluent.ksql.serde.WindowInfo; import java.time.Duration; import java.util.Optional; -import 
java.util.concurrent.TimeUnit; -import org.apache.kafka.common.utils.Bytes; -import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; -import org.apache.kafka.streams.kstream.KGroupedStream; -import org.apache.kafka.streams.kstream.Materialized; -import org.apache.kafka.streams.kstream.TimeWindowedKStream; -import org.apache.kafka.streams.kstream.TimeWindows; -import org.apache.kafka.streams.state.WindowStore; -import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; @RunWith(MockitoJUnitRunner.class) public class TumblingWindowExpressionTest { - - @Mock - private KGroupedStream stream; - @Mock - private TimeWindowedKStream windowedKStream; - @Mock - private UdafAggregator aggregator; - @Mock - private Initializer initializer; - @Mock - private Materialized> store; - private TumblingWindowExpression windowExpression; - - @Before - public void setUp() { - windowExpression = new TumblingWindowExpression(10, TimeUnit.SECONDS); - - when(stream - .windowedBy(any(TimeWindows.class))) - .thenReturn(windowedKStream); - } - - @Test - public void shouldCreateTumblingWindowAggregate() { - // When: - windowExpression.applyAggregate(stream, initializer, aggregator, store); - - // Then: - verify(stream).windowedBy(TimeWindows.of(Duration.ofSeconds(10))); - verify(windowedKStream).aggregate(initializer, aggregator, store); - } - @Test public void shouldReturnWindowInfo() { assertThat(new TumblingWindowExpression(11, SECONDS).getWindowInfo(), diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/WindowExpressionTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/WindowExpressionTest.java index 0bf814abed79..f745415fbe7a 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/WindowExpressionTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/WindowExpressionTest.java @@ -18,6 +18,7 
@@ import static org.mockito.Mockito.mock; import com.google.common.testing.EqualsTester; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.parser.NodeLocation; import java.util.Optional; import org.junit.Test; diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateBuilderUtils.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateBuilderUtils.java new file mode 100644 index 000000000000..5d0c5d4a534e --- /dev/null +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateBuilderUtils.java @@ -0,0 +1,57 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ + +package io.confluent.ksql.execution.streams; + +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; +import io.confluent.ksql.serde.KeySerde; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.state.KeyValueStore; + +public final class AggregateBuilderUtils { + private AggregateBuilderUtils() { + } + + static Materialized> buildMaterialized( + final QueryContext queryContext, + final LogicalSchema aggregateSchema, + final Formats formats, + final KsqlQueryBuilder queryBuilder, + final MaterializedFactory materializedFactory) { + final PhysicalSchema physicalAggregationSchema = PhysicalSchema.from( + aggregateSchema, + formats.getOptions() + ); + final KeySerde keySerde = queryBuilder.buildKeySerde( + formats.getKeyFormat().getFormatInfo(), + physicalAggregationSchema, + queryContext + ); + final Serde valueSerde = queryBuilder.buildValueSerde( + formats.getValueFormat().getFormatInfo(), + physicalAggregationSchema, + queryContext + ); + return materializedFactory.create(keySerde, valueSerde, StreamsUtil.buildOpName(queryContext)); + } +} diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateParams.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateParams.java new file mode 100644 index 000000000000..145dcdf1c18d --- /dev/null +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateParams.java @@ -0,0 +1,96 @@ +/* + * Copyright 2019 Confluent Inc. 
+ * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.streams; + +import com.google.common.collect.ImmutableMap; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.function.TableAggregationFunction; +import io.confluent.ksql.execution.function.UdafUtil; +import io.confluent.ksql.execution.function.udaf.KudafAggregator; +import io.confluent.ksql.execution.function.udaf.KudafInitializer; +import io.confluent.ksql.execution.function.udaf.KudafUndoAggregator; +import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.function.KsqlAggregateFunction; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +public final class AggregateParams { + private final KudafInitializer initializer; + private final int initialUdafIndex; + private final Map indexToFunction; + + AggregateParams( + final LogicalSchema internalSchema, + final int initialUdafIndex, + final FunctionRegistry functionRegistry, + final List functionList + ) { + final List initialValueSuppliers = new LinkedList<>(); + int udafIndexInAggSchema = initialUdafIndex; + final Map indexToFunction = new HashMap<>(); + for (final FunctionCall functionCall : 
functionList) { + final KsqlAggregateFunction aggregateFunction = UdafUtil.resolveAggregateFunction( + functionRegistry, + functionCall, + internalSchema + ); + + indexToFunction.put(udafIndexInAggSchema++, aggregateFunction); + initialValueSuppliers.add(aggregateFunction.getInitialValueSupplier()); + } + this.initialUdafIndex = initialUdafIndex; + this.initializer = new KudafInitializer(initialUdafIndex, initialValueSuppliers); + this.indexToFunction = ImmutableMap.copyOf(indexToFunction); + } + + public KudafInitializer getInitializer() { + return initializer; + } + + public KudafAggregator getAggregator() { + return new KudafAggregator(initialUdafIndex, indexToFunction); + } + + public KudafUndoAggregator getUndoAggregator() { + final Map indexToUndo = + indexToFunction.keySet() + .stream() + .collect( + Collectors.toMap( + k -> k, + k -> ((TableAggregationFunction) indexToFunction.get(k)))); + return new KudafUndoAggregator(initialUdafIndex, indexToUndo); + } + + public WindowSelectMapper getWindowSelectMapper() { + return new WindowSelectMapper(indexToFunction); + } + + public interface Factory { + AggregateParams create( + LogicalSchema internalSchema, + int initialUdafIndex, + FunctionRegistry functionRegistry, + List functionList + ); + } +} diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java index 9f1e3ca89d5d..c3ff96599917 100644 --- a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java @@ -36,12 +36,14 @@ import io.confluent.ksql.execution.plan.StreamStreamJoin; import io.confluent.ksql.execution.plan.StreamTableJoin; import io.confluent.ksql.execution.plan.StreamToTable; +import io.confluent.ksql.execution.plan.StreamWindowedAggregate; import 
io.confluent.ksql.execution.plan.TableAggregate; import io.confluent.ksql.execution.plan.TableFilter; import io.confluent.ksql.execution.plan.TableGroupBy; import io.confluent.ksql.execution.plan.TableMapValues; import io.confluent.ksql.execution.plan.TableSink; import io.confluent.ksql.execution.plan.TableTableJoin; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.util.timestamp.TimestampExtractionPolicy; import java.time.Duration; @@ -302,22 +304,45 @@ public static TableTableJoin> tableTableJoin( ); } - public static StreamAggregate, KGroupedStream> - streamAggregate( - final QueryContext.Stacker stacker, - final ExecutionStep> sourceStep, - final LogicalSchema resultSchema, - final Formats formats, - final int nonFuncColumnCount, - final List aggregations + public static StreamAggregate streamAggregate( + final QueryContext.Stacker stacker, + final ExecutionStep> sourceStep, + final LogicalSchema resultSchema, + final Formats formats, + final int nonFuncColumnCount, + final List aggregations, + final LogicalSchema aggregateSchema + ) { + final QueryContext queryContext = stacker.getQueryContext(); + return new StreamAggregate( + new DefaultExecutionStepProperties(resultSchema, queryContext), + sourceStep, + formats, + nonFuncColumnCount, + aggregations, + aggregateSchema + ); + } + + public static StreamWindowedAggregate streamWindowedAggregate( + final QueryContext.Stacker stacker, + final ExecutionStep> sourceStep, + final LogicalSchema resultSchema, + final Formats formats, + final int nonFuncColumnCount, + final List aggregations, + final LogicalSchema aggregateSchema, + final KsqlWindowExpression window ) { final QueryContext queryContext = stacker.getQueryContext(); - return new StreamAggregate<>( + return new StreamWindowedAggregate( new DefaultExecutionStepProperties(resultSchema, queryContext), sourceStep, formats, nonFuncColumnCount, - aggregations + 
aggregations, + aggregateSchema, + window ); } @@ -349,22 +374,23 @@ public static StreamGroupByKey streamGroupByKey( ); } - public static TableAggregate, KGroupedTable> - tableAggregate( - final QueryContext.Stacker stacker, - final ExecutionStep> sourceStep, - final LogicalSchema resultSchema, - final Formats formats, - final int nonFuncColumnCount, - final List aggregations + public static TableAggregate tableAggregate( + final QueryContext.Stacker stacker, + final ExecutionStep> sourceStep, + final LogicalSchema resultSchema, + final Formats formats, + final int nonFuncColumnCount, + final List aggregations, + final LogicalSchema aggregateSchema ) { final QueryContext queryContext = stacker.getQueryContext(); - return new TableAggregate<>( + return new TableAggregate( new DefaultExecutionStepProperties(resultSchema, queryContext), sourceStep, formats, nonFuncColumnCount, - aggregations + aggregations, + aggregateSchema ); } diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamAggregateBuilder.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamAggregateBuilder.java new file mode 100644 index 000000000000..56ac16b5c07c --- /dev/null +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamAggregateBuilder.java @@ -0,0 +1,227 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ + +package io.confluent.ksql.execution.streams; + +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.StreamAggregate; +import io.confluent.ksql.execution.plan.StreamWindowedAggregate; +import io.confluent.ksql.execution.windows.HoppingWindowExpression; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; +import io.confluent.ksql.execution.windows.SessionWindowExpression; +import io.confluent.ksql.execution.windows.TumblingWindowExpression; +import io.confluent.ksql.execution.windows.WindowVisitor; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; +import io.confluent.ksql.serde.KeySerde; +import java.time.Duration; +import java.util.Objects; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.KeyValueStore; + +public final class StreamAggregateBuilder { + private StreamAggregateBuilder() { + } + + public static KTable build( + final KGroupedStream groupedStream, + final StreamAggregate aggregate, + final KsqlQueryBuilder queryBuilder, + final MaterializedFactory materializedFactory) { + return build(groupedStream, aggregate, queryBuilder, materializedFactory, AggregateParams::new); + } + + static KTable build( + final KGroupedStream kgroupedStream, + final StreamAggregate 
aggregate, + final KsqlQueryBuilder queryBuilder, + final MaterializedFactory materializedFactory, + final AggregateParams.Factory aggregateParamsFactory) { + final LogicalSchema sourceSchema = aggregate.getSources().get(0).getSchema(); + final int nonFuncColumns = aggregate.getNonFuncColumnCount(); + final AggregateParams aggregateParams = aggregateParamsFactory.create( + sourceSchema, + nonFuncColumns, + queryBuilder.getFunctionRegistry(), + aggregate.getAggregations() + ); + final Materialized> materialized = + AggregateBuilderUtils.buildMaterialized( + aggregate.getProperties().getQueryContext(), + aggregate.getAggregationSchema(), + aggregate.getFormats(), + queryBuilder, + materializedFactory + ); + final KTable aggregated = kgroupedStream.aggregate( + aggregateParams.getInitializer(), + aggregateParams.getAggregator(), + materialized + ); + return aggregated.mapValues(aggregateParams.getAggregator().getResultMapper()); + } + + public static KTable, GenericRow> build( + final KGroupedStream groupedStream, + final StreamWindowedAggregate aggregate, + final KsqlQueryBuilder queryBuilder, + final MaterializedFactory materializedFactory + ) { + return build(groupedStream, aggregate, queryBuilder, materializedFactory, AggregateParams::new); + } + + static KTable, GenericRow> build( + final KGroupedStream groupedStream, + final StreamWindowedAggregate aggregate, + final KsqlQueryBuilder queryBuilder, + final MaterializedFactory materializedFactory, + final AggregateParams.Factory aggregateParamsFactory + ) { + final LogicalSchema sourceSchema = aggregate.getSources().get(0).getSchema(); + final int nonFuncColumns = aggregate.getNonFuncColumnCount(); + final AggregateParams aggregateParams = aggregateParamsFactory.create( + sourceSchema, + nonFuncColumns, + queryBuilder.getFunctionRegistry(), + aggregate.getAggregations() + ); + final KsqlWindowExpression ksqlWindowExpression = aggregate.getWindowExpression(); + final KTable, GenericRow> aggregated = 
ksqlWindowExpression.accept( + new WindowedAggregator( + groupedStream, + aggregate, + queryBuilder, + materializedFactory, + aggregateParams + ), + null + ); + final KTable, GenericRow> reduced = aggregated.mapValues( + aggregateParams.getAggregator().getResultMapper() + ); + final WindowSelectMapper windowSelectMapper = aggregateParams.getWindowSelectMapper(); + if (!windowSelectMapper.hasSelects()) { + return reduced; + } + return reduced.mapValues(windowSelectMapper); + } + + private static class WindowedAggregator + implements WindowVisitor, GenericRow>, Void> { + final QueryContext queryContext; + final Formats formats; + final KGroupedStream groupedStream; + final KsqlQueryBuilder queryBuilder; + final MaterializedFactory materializedFactory; + final KeySerde keySerde; + final Serde valueSerde; + final AggregateParams aggregateParams; + + WindowedAggregator( + final KGroupedStream groupedStream, + final StreamWindowedAggregate aggregate, + final KsqlQueryBuilder queryBuilder, + final MaterializedFactory materializedFactory, + final AggregateParams aggregateParams) { + Objects.requireNonNull(aggregate, "aggregate"); + this.groupedStream = Objects.requireNonNull(groupedStream, "groupedStream"); + this.queryBuilder = Objects.requireNonNull(queryBuilder, "queryBuilder"); + this.materializedFactory = Objects.requireNonNull(materializedFactory, "materializedFactory"); + this.aggregateParams = Objects.requireNonNull(aggregateParams, "aggregateParams"); + this.queryContext = aggregate.getProperties().getQueryContext(); + this.formats = aggregate.getFormats(); + final PhysicalSchema physicalSchema = PhysicalSchema.from( + aggregate.getAggregationSchema(), + formats.getOptions() + ); + keySerde = queryBuilder.buildKeySerde( + formats.getKeyFormat().getFormatInfo(), + physicalSchema, + queryContext + ); + valueSerde = queryBuilder.buildValueSerde( + formats.getValueFormat().getFormatInfo(), + physicalSchema, + queryContext + ); + } + + @Override + public KTable, 
GenericRow> visitHoppingWindowExpression( + final HoppingWindowExpression window, + final Void ctx) { + final TimeWindows windows = TimeWindows + .of(Duration.ofMillis(window.getSizeUnit().toMillis(window.getSize()))) + .advanceBy( + Duration.ofMillis(window.getAdvanceByUnit().toMillis(window.getAdvanceBy())) + ); + + return groupedStream + .windowedBy(windows) + .aggregate( + aggregateParams.getInitializer(), + aggregateParams.getAggregator(), + materializedFactory.create( + keySerde, valueSerde, StreamsUtil.buildOpName(queryContext)) + ); + } + + @Override + public KTable, GenericRow> visitSessionWindowExpression( + final SessionWindowExpression window, + final Void ctx) { + final SessionWindows windows = SessionWindows.with( + Duration.ofMillis(window.getSizeUnit().toMillis(window.getGap())) + ); + return groupedStream + .windowedBy(windows) + .aggregate( + aggregateParams.getInitializer(), + aggregateParams.getAggregator(), + aggregateParams.getAggregator().getMerger(), + materializedFactory.create( + keySerde, valueSerde, StreamsUtil.buildOpName(queryContext)) + ); + } + + @Override + public KTable, GenericRow> visitTumblingWindowExpression( + final TumblingWindowExpression window, + final Void ctx) { + final TimeWindows windows = TimeWindows.of( + Duration.ofMillis(window.getSizeUnit().toMillis(window.getSize()))); + return groupedStream + .windowedBy(windows) + .aggregate( + aggregateParams.getInitializer(), + aggregateParams.getAggregator(), + materializedFactory.create( + keySerde, valueSerde, StreamsUtil.buildOpName(queryContext)) + ); + } + } +} diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/TableAggregateBuilder.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/TableAggregateBuilder.java new file mode 100644 index 000000000000..0c716f943ae6 --- /dev/null +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/TableAggregateBuilder.java @@ -0,0 +1,76 @@ +/* + * Copyright 2019 Confluent Inc. 
+ * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.streams; + +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.plan.TableAggregate; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedTable; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.state.KeyValueStore; + +public final class TableAggregateBuilder { + private TableAggregateBuilder() { + } + + public static KTable build( + final KGroupedTable kgroupedTable, + final TableAggregate aggregate, + final KsqlQueryBuilder queryBuilder, + final MaterializedFactory materializedFactory) { + return build( + kgroupedTable, + aggregate, + queryBuilder, + materializedFactory, + AggregateParams::new + ); + } + + public static KTable build( + final KGroupedTable kgroupedTable, + final TableAggregate aggregate, + final KsqlQueryBuilder queryBuilder, + final MaterializedFactory materializedFactory, + final AggregateParams.Factory aggregateParamsFactory) { + final LogicalSchema sourceSchema = aggregate.getSources().get(0).getSchema(); + final int nonFuncColumns = aggregate.getNonFuncColumnCount(); + final AggregateParams aggregateParams = aggregateParamsFactory.create( + 
sourceSchema, + nonFuncColumns, + queryBuilder.getFunctionRegistry(), + aggregate.getAggregations() + ); + final Materialized> materialized = + AggregateBuilderUtils.buildMaterialized( + aggregate.getProperties().getQueryContext(), + aggregate.getAggregateSchema(), + aggregate.getFormats(), + queryBuilder, + materializedFactory + ); + return kgroupedTable.aggregate( + aggregateParams.getInitializer(), + aggregateParams.getAggregator(), + aggregateParams.getUndoAggregator(), + materialized + ).mapValues(aggregateParams.getAggregator().getResultMapper()); + } +} \ No newline at end of file diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/AggregateParamsTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/AggregateParamsTest.java new file mode 100644 index 000000000000..0ba3cce88d8c --- /dev/null +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/AggregateParamsTest.java @@ -0,0 +1,176 @@ +package io.confluent.ksql.execution.streams; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.expression.tree.QualifiedName; +import io.confluent.ksql.execution.expression.tree.QualifiedNameReference; +import io.confluent.ksql.execution.function.TableAggregationFunction; +import io.confluent.ksql.execution.function.udaf.KudafAggregator; +import io.confluent.ksql.execution.function.udaf.KudafInitializer; +import io.confluent.ksql.execution.function.udaf.KudafUndoAggregator; +import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper; +import 
io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.function.KsqlAggregateFunction; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import java.util.List; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class AggregateParamsTest { + private static final LogicalSchema INPUT_SCHEMA = LogicalSchema.builder() + .valueColumn("REQUIRED0", SqlTypes.BIGINT) + .valueColumn("REQUIRED1", SqlTypes.STRING) + .valueColumn("ARGUMENT0", SqlTypes.INTEGER) + .valueColumn("ARGUMENT1", SqlTypes.DOUBLE) + .build(); + private static final FunctionCall AGG0 = new FunctionCall( + QualifiedName.of("AGG0"), + ImmutableList.of(new QualifiedNameReference(QualifiedName.of("ARGUMENT0"))) + ); + private static final long INITIAL_VALUE0 = 123; + private static final FunctionCall AGG1 = new FunctionCall( + QualifiedName.of("AGG1"), + ImmutableList.of(new QualifiedNameReference(QualifiedName.of("ARGUMENT1"))) + ); + private static final FunctionCall TABLE_AGG = new FunctionCall( + QualifiedName.of("TABLE_AGG"), + ImmutableList.of(new QualifiedNameReference(QualifiedName.of("ARGUMENT0"))) + ); + private static final FunctionCall WINDOW_START = new FunctionCall( + QualifiedName.of("WindowStart"), + ImmutableList.of(new QualifiedNameReference(QualifiedName.of("ARGUMENT0"))) + ); + private static final String INITIAL_VALUE1 = "initial"; + private static final List FUNCTIONS = ImmutableList.of(AGG0, AGG1); + + @Mock + private FunctionRegistry functionRegistry; + @Mock + private KsqlAggregateFunction agg0; + @Mock + private KsqlAggregateFunction agg0Resolved; + @Mock + private KsqlAggregateFunction agg1; + @Mock + private KsqlAggregateFunction agg1Resolved; + 
@Mock + private TableAggregationFunction tableAgg; + @Mock + private KsqlAggregateFunction windowStart; + + private AggregateParams aggregateParams; + + @Before + @SuppressWarnings("unchecked") + public void init() { + when(functionRegistry.getAggregate(same(AGG0.getName().name()), any())).thenReturn(agg0); + when(agg0Resolved.getInitialValueSupplier()).thenReturn(() -> INITIAL_VALUE0); + when(agg0Resolved.getFunctionName()).thenReturn(AGG0.getName().name()); + when(agg0.getInstance(any())).thenReturn(agg0Resolved); + when(functionRegistry.getAggregate(same(AGG1.getName().name()), any())).thenReturn(agg1); + when(agg1Resolved.getInitialValueSupplier()).thenReturn(() -> INITIAL_VALUE1); + when(agg1Resolved.getFunctionName()).thenReturn(AGG1.getName().name()); + when(agg1.getInstance(any())).thenReturn(agg1Resolved); + when(functionRegistry.getAggregate(same(TABLE_AGG.getName().name()), any())) + .thenReturn(tableAgg); + when(tableAgg.getInitialValueSupplier()).thenReturn(() -> INITIAL_VALUE0); + when(tableAgg.getInstance(any())).thenReturn(tableAgg); + when(functionRegistry.getAggregate(same(WINDOW_START.getName().name()), any())) + .thenReturn(windowStart); + when(windowStart.getInitialValueSupplier()).thenReturn(() -> INITIAL_VALUE0); + when(windowStart.getInstance(any())).thenReturn(windowStart); + when(windowStart.getFunctionName()).thenReturn(WINDOW_START.getName().name()); + aggregateParams = new AggregateParams( + INPUT_SCHEMA, + 2, + functionRegistry, + FUNCTIONS + ); + } + + @Test + public void shouldReturnCorrectAggregator() { + // When: + final KudafAggregator aggregator = aggregateParams.getAggregator(); + + // Then: + assertThat(aggregator.getNonFuncColumnCount(), equalTo(2)); + assertThat( + aggregator.getAggValToAggFunctionMap(), + equalTo(ImmutableMap.of(2, agg0Resolved, 3, agg1Resolved)) + ); + } + + @Test + public void shouldReturnCorrectInitializer() { + // When: + final KudafInitializer initializer = aggregateParams.getInitializer(); + + // Then: 
+ assertThat( + initializer.apply(), + equalTo(new GenericRow(null, null, INITIAL_VALUE0, INITIAL_VALUE1)) + ); + } + + @Test + public void shouldReturnUndoAggregator() { + // Given: + aggregateParams = + new AggregateParams(INPUT_SCHEMA, 2, functionRegistry, ImmutableList.of(TABLE_AGG)); + + // When: + final KudafUndoAggregator undoAggregator = aggregateParams.getUndoAggregator(); + + // Then: + assertThat(undoAggregator.getNonFuncColumnCount(), equalTo(2)); + assertThat( + undoAggregator.getAggValToAggFunctionMap(), + equalTo(ImmutableMap.of(2, tableAgg)) + ); + } + + @Test + public void shouldReturnCorrectWindowSelectMapperForNonWindowSelections() { + // When: + final WindowSelectMapper windowSelectMapper = aggregateParams.getWindowSelectMapper(); + + // Then: + assertThat(windowSelectMapper.hasSelects(), is(false)); + } + + @Test + public void shouldReturnCorrectWindowSelectMapperForWindowSelections() { + // Given: + aggregateParams = new AggregateParams( + INPUT_SCHEMA, + 2, + functionRegistry, + ImmutableList.of(WINDOW_START) + ); + + // When: + final WindowSelectMapper windowSelectMapper = aggregateParams.getWindowSelectMapper(); + final Windowed window = new Windowed<>(null, new TimeWindow(10, 20)); + assertThat( + windowSelectMapper.apply(window, new GenericRow("fiz", "baz", null)), + equalTo(new GenericRow("fiz", "baz", 10)) + ); + } +} diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamAggregateBuilderTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamAggregateBuilderTest.java new file mode 100644 index 000000000000..6ed12c2d339d --- /dev/null +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamAggregateBuilderTest.java @@ -0,0 +1,619 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. 
You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.streams; + +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.expression.tree.QualifiedName; +import io.confluent.ksql.execution.expression.tree.QualifiedNameReference; +import io.confluent.ksql.execution.function.udaf.KudafAggregator; +import io.confluent.ksql.execution.function.udaf.KudafInitializer; +import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper; +import io.confluent.ksql.execution.plan.DefaultExecutionStepProperties; +import io.confluent.ksql.execution.plan.ExecutionStep; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.StreamAggregate; +import io.confluent.ksql.execution.plan.StreamWindowedAggregate; +import io.confluent.ksql.execution.windows.HoppingWindowExpression; +import io.confluent.ksql.execution.windows.SessionWindowExpression; +import 
io.confluent.ksql.execution.windows.TumblingWindowExpression; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.query.QueryId; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.serde.Format; +import io.confluent.ksql.serde.FormatInfo; +import io.confluent.ksql.serde.KeyFormat; +import io.confluent.ksql.serde.KeySerde; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.ValueFormat; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Merger; +import org.apache.kafka.streams.kstream.SessionWindowedKStream; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.TimeWindowedKStream; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.ValueMapper; +import org.apache.kafka.streams.kstream.ValueMapperWithKey; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.Windows; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.WindowStore; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +/** Unit tests for StreamAggregateBuilder.build(). Every collaborator is a Mockito mock, so the tests check only which Kafka Streams calls, serdes, state-store name and aggregate params the builder uses for unwindowed, tumbling, hopping and session-windowed steps. NOTE(review): generic type arguments look stripped in this listing (for example "KTable, GenericRow>"); confirm against the original commit. */ @RunWith(MockitoJUnitRunner.class) +public class StreamAggregateBuilderTest { + /** Schema the mocked source step returns from getSchema(). */ private static final LogicalSchema INPUT_SCHEMA =
LogicalSchema.builder() + .valueColumn("REQUIRED0", SqlTypes.BIGINT) + .valueColumn("REQUIRED1", SqlTypes.STRING) + .valueColumn("ARGUMENT0", SqlTypes.INTEGER) + .valueColumn("ARGUMENT1", SqlTypes.DOUBLE) + .build(); + /** Schema handed to the step under test as its aggregate schema. */ private static final LogicalSchema AGGREGATE_SCHEMA = LogicalSchema.builder() + .valueColumn("REQUIRED0", SqlTypes.BIGINT) + .valueColumn("REQUIRED1", SqlTypes.STRING) + .valueColumn("RESULT0", SqlTypes.BIGINT) + .valueColumn("RESULT1", SqlTypes.STRING) + .build(); + /** Physical form of AGGREGATE_SCHEMA that the serde tests expect the builder to pass to the query builder. */ private static final PhysicalSchema PHYSICAL_AGGREGATE_SCHEMA = PhysicalSchema.from( + AGGREGATE_SCHEMA, + SerdeOption.none() + ); + private static final FunctionCall AGG0 = new FunctionCall( + QualifiedName.of("AGG0"), + ImmutableList.of(new QualifiedNameReference(QualifiedName.of("ARGUMENT0"))) + ); + private static final FunctionCall AGG1 = new FunctionCall( + QualifiedName.of("AGG1"), + ImmutableList.of(new QualifiedNameReference(QualifiedName.of("ARGUMENT1"))) + ); + /** Aggregate function calls handed to the step under test. */ private static final List FUNCTIONS = ImmutableList.of(AGG0, AGG1); + /** Query context built by pushing "agg" then "regate"; the name tests expect the store name "agg-regate". */ private static final QueryContext CTX = + new QueryContext.Stacker(new QueryId("qid")).push("agg").push("regate").getQueryContext(); + private static final KeyFormat KEY_FORMAT = KeyFormat.nonWindowed(FormatInfo.of(Format.KAFKA)); + private static final ValueFormat VALUE_FORMAT = ValueFormat.of(FormatInfo.of(Format.JSON)); + /** Window size (30000 ms) and hop (10000 ms) used by the windowed fixtures. */ private static final Duration WINDOW = Duration.ofMillis(30000); + private static final Duration HOP = Duration.ofMillis(10000); + + @Mock + private KGroupedStream groupedStream; + @Mock + private KTable aggregated; + @Mock + private KTable aggregatedWithResults; + @Mock + private TimeWindowedKStream timeWindowedStream; + @Mock + private SessionWindowedKStream sessionWindowedStream; + @Mock + private KTable, GenericRow> windowed; + @Mock + private KTable, GenericRow> windowedWithResults; + @Mock + private KTable, GenericRow> windowedWithWindowBoundaries; + @Mock + private KsqlQueryBuilder queryBuilder; + @Mock + private
FunctionRegistry functionRegistry; + @Mock + private AggregateParams.Factory aggregateParamsFactory; + @Mock + private AggregateParams aggregateParams; + @Mock + private KudafInitializer initializer; + @Mock + private KudafAggregator aggregator; + @Mock + private ValueMapper resultMapper; + @Mock + private WindowSelectMapper windowSelectMapper; + @Mock + private Merger merger; + @Mock + private MaterializedFactory materializedFactory; + @Mock + private KeySerde keySerde; + @Mock + private Serde valueSerde; + @Mock + private Materialized> materialized; + @Mock + private Materialized> timeWindowMaterialized; + @Mock + private Materialized> sessionWindowMaterialized; + @Mock + private ExecutionStep> sourceStep; + + /** Steps under test; assigned by the given...() fixture methods. */ private StreamAggregate aggregate; + private StreamWindowedAggregate windowedAggregate; + + /** Installs the stubs every test relies on: source schema, key/value serde construction, function registry, and the aggregate params with their initializer, aggregator, merger, result mapper and window-select mapper (hasSelects() is false by default). */ @Before + @SuppressWarnings("unchecked") + public void init() { + when(sourceStep.getSchema()).thenReturn(INPUT_SCHEMA); + when(queryBuilder.buildKeySerde(any(), any(), any())).thenReturn(keySerde); + when(queryBuilder.buildValueSerde(any(), any(), any())).thenReturn(valueSerde); + when(queryBuilder.getFunctionRegistry()).thenReturn(functionRegistry); + when(aggregateParamsFactory.create(any(), anyInt(), any(), any())).thenReturn(aggregateParams); + when(aggregateParams.getAggregator()).thenReturn(aggregator); + when(aggregator.getMerger()).thenReturn(merger); + when(aggregator.getResultMapper()).thenReturn(resultMapper); + when(aggregateParams.getInitializer()).thenReturn(initializer); + when(aggregateParams.getWindowSelectMapper()).thenReturn(windowSelectMapper); + when(windowSelectMapper.hasSelects()).thenReturn(false); + } + + /** Fixture for the unwindowed path: stubs materializedFactory.create, groupedStream.aggregate and aggregated.mapValues, then builds the StreamAggregate step. */ @SuppressWarnings("unchecked") + private void givenUnwindowedAggregate() { + when(materializedFactory.>create(any(), any(), any())) + .thenReturn(materialized); + when(groupedStream.aggregate(any(), any(), any(Materialized.class))).thenReturn(aggregated); +
when(aggregated.mapValues(any(ValueMapper.class))).thenReturn(aggregatedWithResults); + aggregate = new StreamAggregate( + new DefaultExecutionStepProperties(INPUT_SCHEMA, CTX), + sourceStep, + Formats.of(KEY_FORMAT, VALUE_FORMAT, SerdeOption.none()), + /* presumably the number of leading non-aggregate columns (REQUIRED0, REQUIRED1) - TODO confirm */ 2, + FUNCTIONS, + AGGREGATE_SCHEMA + ); + } + + /** Stubs shared by the tumbling and hopping fixtures: windowedBy(Windows) yields timeWindowedStream, whose aggregate yields windowed, whose mapValues yields windowedWithResults. */ @SuppressWarnings("unchecked") + private void givenTimeWindowedAggregate() { + when(materializedFactory.>create(any(), any(), any())) + .thenReturn(timeWindowMaterialized); + when(groupedStream.windowedBy(any(Windows.class))).thenReturn(timeWindowedStream); + when(timeWindowedStream.aggregate(any(), any(), any(Materialized.class))) + .thenReturn(windowed); + when(windowed.mapValues(any(ValueMapper.class))).thenReturn(windowedWithResults); + } + + /** Fixture: time-windowed stubs plus a step with a tumbling window of WINDOW seconds. */ private void givenTumblingWindowedAggregate() { + givenTimeWindowedAggregate(); + windowedAggregate = new StreamWindowedAggregate( + new DefaultExecutionStepProperties(INPUT_SCHEMA, CTX), + sourceStep, + Formats.of(KEY_FORMAT, VALUE_FORMAT, SerdeOption.none()), + 2, + FUNCTIONS, + AGGREGATE_SCHEMA, + new TumblingWindowExpression(WINDOW.getSeconds(), TimeUnit.SECONDS) + ); + } + + /** Fixture: time-windowed stubs plus a step with a hopping window of WINDOW seconds advancing by HOP seconds. */ private void givenHoppingWindowedAggregate() { + givenTimeWindowedAggregate(); + windowedAggregate = new StreamWindowedAggregate( + new DefaultExecutionStepProperties(INPUT_SCHEMA, CTX), + sourceStep, + Formats.of(KEY_FORMAT, VALUE_FORMAT, SerdeOption.none()), + 2, + FUNCTIONS, + AGGREGATE_SCHEMA, + new HoppingWindowExpression( + WINDOW.getSeconds(), + TimeUnit.SECONDS, + HOP.getSeconds(), + TimeUnit.SECONDS + ) + ); + } + + /** Fixture: stubs the session path (windowedBy(SessionWindows), four-argument aggregate, mapValues) and builds a step with a session window of WINDOW seconds. */ @SuppressWarnings("unchecked") + private void givenSessionWindowedAggregate() { + when(materializedFactory.>create(any(), any(), any())) + .thenReturn(sessionWindowMaterialized); + when(groupedStream.windowedBy(any(SessionWindows.class))).thenReturn(sessionWindowedStream); + when(sessionWindowedStream.aggregate(any(), any(), any(), any(Materialized.class))) + .thenReturn(windowed); +
when(windowed.mapValues(any(ValueMapper.class))).thenReturn(windowedWithResults); + windowedAggregate = new StreamWindowedAggregate( + new DefaultExecutionStepProperties(INPUT_SCHEMA, CTX), + sourceStep, + Formats.of(KEY_FORMAT, VALUE_FORMAT, SerdeOption.none()), + 2, + FUNCTIONS, + AGGREGATE_SCHEMA, + new SessionWindowExpression(WINDOW.getSeconds(), TimeUnit.SECONDS) + ); + } + + /** Expects aggregate(initializer, aggregator, materialized) then mapValues(resultMapper), in that order and with no other calls, and the mapped table as the result. */ @Test + public void shouldBuildUnwindowedAggregateCorrectly() { + // Given: + givenUnwindowedAggregate(); + + // When: + final KTable result = StreamAggregateBuilder.build( + groupedStream, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + assertThat(result, is(aggregatedWithResults)); + final InOrder inOrder = Mockito.inOrder(groupedStream, aggregated, aggregatedWithResults); + inOrder.verify(groupedStream).aggregate(initializer, aggregator, materialized); + inOrder.verify(aggregated).mapValues(resultMapper); + inOrder.verifyNoMoreInteractions(); + } + + /** Expects the materialized to be created with the serdes returned by the query builder. */ @Test + public void shouldBuildMaterializedWithCorrectSerdesForUnwindowedAggregate() { + // Given: + givenUnwindowedAggregate(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(materializedFactory).create(same(keySerde), same(valueSerde), any()); + } + + /** Expects the materialized to be named "agg-regate". */ @Test + public void shouldBuildMaterializedWithCorrectNameForUnwindowedAggregate() { + // Given: + givenUnwindowedAggregate(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(materializedFactory).create(any(), any(), eq("agg-regate")); + } + + /** Expects the key serde to be built from the key format, PHYSICAL_AGGREGATE_SCHEMA and CTX. */ @Test + public void shouldBuildKeySerdeCorrectlyForUnwindowedAggregate() { + // Given: + givenUnwindowedAggregate(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + //
Then: + verify(queryBuilder).buildKeySerde(KEY_FORMAT.getFormatInfo(), PHYSICAL_AGGREGATE_SCHEMA, CTX); + } + + /** Expects the value serde to be built from the value format, PHYSICAL_AGGREGATE_SCHEMA and CTX. */ @Test + public void shouldBuildValueSerdeCorrectlyForUnwindowedAggregate() { + // Given: + givenUnwindowedAggregate(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(queryBuilder).buildValueSerde( + VALUE_FORMAT.getFormatInfo(), + PHYSICAL_AGGREGATE_SCHEMA, + CTX + ); + } + + /** Expects the aggregate params to be created from INPUT_SCHEMA, 2, the function registry and FUNCTIONS. */ @Test + public void shouldBuildAggregatorParamsCorrectlyForUnwindowedAggregate() { + // Given: + givenUnwindowedAggregate(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(aggregateParamsFactory).create(INPUT_SCHEMA, 2, functionRegistry, FUNCTIONS); + } + + /** Expects windowedBy(TimeWindows.of(WINDOW)), then aggregate on the windowed stream, then mapValues(resultMapper), in order. */ @Test + public void shouldBuildTumblingWindowedAggregateCorrectly() { + // Given: + givenTumblingWindowedAggregate(); + + // When: + final KTable, GenericRow> result = StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + assertThat(result, is(windowedWithResults)); + final InOrder inOrder = Mockito.inOrder( + groupedStream, + timeWindowedStream, + windowed, + windowedWithResults + ); + inOrder.verify(groupedStream).windowedBy(TimeWindows.of(WINDOW)); + inOrder.verify(timeWindowedStream).aggregate(initializer, aggregator, timeWindowMaterialized); + inOrder.verify(windowed).mapValues(resultMapper); + inOrder.verifyNoMoreInteractions(); + } + + /** Same as the tumbling case but with TimeWindows.of(WINDOW).advanceBy(HOP). */ @Test + public void shouldBuildHoppingWindowedAggregateCorrectly() { + // Given: + givenHoppingWindowedAggregate(); + + // When: + final KTable, GenericRow> result = StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + assertThat(result, is(windowedWithResults)); +
final InOrder inOrder = Mockito.inOrder( + groupedStream, + timeWindowedStream, + windowed, + windowedWithResults + ); + inOrder.verify(groupedStream).windowedBy(TimeWindows.of(WINDOW).advanceBy(HOP)); + inOrder.verify(timeWindowedStream).aggregate(initializer, aggregator, timeWindowMaterialized); + inOrder.verify(windowed).mapValues(resultMapper); + inOrder.verifyNoMoreInteractions(); + } + + /** Expects windowedBy(SessionWindows.with(WINDOW)), then the four-argument aggregate including the merger, then mapValues(resultMapper), in order. */ @Test + public void shouldBuildSessionWindowedAggregateCorrectly() { + // Given: + givenSessionWindowedAggregate(); + + // When: + final KTable, GenericRow> result = StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + assertThat(result, is(windowedWithResults)); + final InOrder inOrder = Mockito.inOrder( + groupedStream, + sessionWindowedStream, + windowed, + windowedWithResults + ); + inOrder.verify(groupedStream).windowedBy(SessionWindows.with(WINDOW)); + inOrder.verify(sessionWindowedStream).aggregate( + initializer, + aggregator, + merger, + sessionWindowMaterialized + ); + inOrder.verify(windowed).mapValues(resultMapper); + inOrder.verifyNoMoreInteractions(); + } + + /** Returns the hopping, tumbling and session fixtures so a windowed test can run once per window type. */ private List given() { + return ImmutableList.of( + this::givenHoppingWindowedAggregate, + this::givenTumblingWindowedAggregate, + this::givenSessionWindowedAggregate + ); + } + + /** Runs the serde check once per window type; the listed mocks are reset first, which clears interactions recorded in the previous iteration. */ @Test + public void shouldBuildMaterializedWithCorrectSerdesForWindowedAggregate() { + for (final Runnable given : given()) { + // Given: + reset(groupedStream, timeWindowedStream, sessionWindowedStream, aggregated, materializedFactory); + given.run(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(materializedFactory).create(same(keySerde), same(valueSerde), any()); + } + } + + /** Runs the store-name check ("agg-regate") once per window type. */ @Test + public void shouldBuildMaterializedWithCorrectNameForWindowedAggregate() { + for (final Runnable given : given()) { + //
Given: + reset(groupedStream, timeWindowedStream, sessionWindowedStream, aggregated, materializedFactory); + given.run(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(materializedFactory).create(any(), any(), eq("agg-regate")); + } + } + + /** Windowed variant of the key serde check. NOTE(review): reset(queryBuilder) also removes the stubs installed by init(), so the builder presumably sees null serdes and a null function registry here - confirm this is intended. */ @Test + public void shouldBuildKeySerdeCorrectlyForWindowedAggregate() { + for (final Runnable given : given()) { + // Given: + reset(groupedStream, timeWindowedStream, sessionWindowedStream, aggregated, queryBuilder); + given.run(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(queryBuilder) + .buildKeySerde(KEY_FORMAT.getFormatInfo(), PHYSICAL_AGGREGATE_SCHEMA, CTX); + } + } + + /** Windowed variant of the value serde check; queryBuilder is reset each iteration in the same way as in the key serde test. */ @Test + public void shouldBuildValueSerdeCorrectlyForWindowedAggregate() { + for (final Runnable given : given()) { + // Given: + reset(groupedStream, timeWindowedStream, sessionWindowedStream, aggregated, queryBuilder); + given.run(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(queryBuilder) + .buildValueSerde(VALUE_FORMAT.getFormatInfo(), PHYSICAL_AGGREGATE_SCHEMA, CTX); + } + } + + /** Windowed variant of the aggregate params check; the factory stub is re-installed inside the loop because reset() removes the one set in init(). */ @Test + public void shouldBuildAggregatorParamsCorrectlyForWindowedAggregate() { + for (final Runnable given : given()) { + // Given: + reset( + groupedStream, + timeWindowedStream, + sessionWindowedStream, + aggregated, + aggregateParamsFactory + ); + when(aggregateParamsFactory.create(any(), anyInt(), any(), any())) + .thenReturn(aggregateParams); + given.run(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(aggregateParamsFactory) + .create(INPUT_SCHEMA, 2,
functionRegistry, FUNCTIONS); + } + } + + /** With hasSelects() stubbed to true, expects the builder to apply windowSelectMapper via mapValues on the result-mapped table and to return that table, for every window type. */ @Test + @SuppressWarnings("unchecked") + public void shouldAddWindowBoundariesIfSpecified() { + for (final Runnable given : given()) { + // Given: + reset( + groupedStream, timeWindowedStream, sessionWindowedStream, windowed, windowedWithResults); + when(windowSelectMapper.hasSelects()).thenReturn(true); + when(windowedWithResults.mapValues(any(ValueMapperWithKey.class))).thenReturn( + windowedWithWindowBoundaries); + given.run(); + + // When: + final KTable, GenericRow> result = StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + assertThat(result, is(windowedWithWindowBoundaries)); + verify(windowedWithResults).mapValues(windowSelectMapper); + } + } +} \ No newline at end of file diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableAggregateBuilderTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableAggregateBuilderTest.java new file mode 100644 index 000000000000..94ac72107cea --- /dev/null +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableAggregateBuilderTest.java @@ -0,0 +1,261 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License.
+ */ + +package io.confluent.ksql.execution.streams; + +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.expression.tree.QualifiedName; +import io.confluent.ksql.execution.expression.tree.QualifiedNameReference; +import io.confluent.ksql.execution.function.udaf.KudafAggregator; +import io.confluent.ksql.execution.function.udaf.KudafInitializer; +import io.confluent.ksql.execution.function.udaf.KudafUndoAggregator; +import io.confluent.ksql.execution.plan.DefaultExecutionStepProperties; +import io.confluent.ksql.execution.plan.ExecutionStep; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.TableAggregate; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.query.QueryId; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.serde.Format; +import io.confluent.ksql.serde.FormatInfo; +import io.confluent.ksql.serde.KeyFormat; +import io.confluent.ksql.serde.KeySerde; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.ValueFormat; +import java.util.List; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import 
org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedTable; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.ValueMapper; +import org.apache.kafka.streams.state.KeyValueStore; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +/** Unit tests for TableAggregateBuilder.build(), run against Mockito mocks of the grouped table and its collaborators. NOTE(review): generic type arguments look stripped in this listing (for example "Materialized>"); confirm against the original commit. */ @RunWith(MockitoJUnitRunner.class) +public class TableAggregateBuilderTest { + /** Schema the mocked source step returns from getSchema(). */ private static final LogicalSchema INPUT_SCHEMA = LogicalSchema.builder() + .valueColumn("REQUIRED0", SqlTypes.BIGINT) + .valueColumn("REQUIRED1", SqlTypes.STRING) + .valueColumn("ARGUMENT0", SqlTypes.INTEGER) + .valueColumn("ARGUMENT1", SqlTypes.DOUBLE) + .build(); + /** Schema handed to the step under test as its aggregate schema. */ private static final LogicalSchema AGGREGATE_SCHEMA = LogicalSchema.builder() + .valueColumn("REQUIRED0", SqlTypes.BIGINT) + .valueColumn("REQUIRED1", SqlTypes.STRING) + .valueColumn("RESULT0", SqlTypes.BIGINT) + .valueColumn("RESULT1", SqlTypes.STRING) + .build(); + /** Physical form of AGGREGATE_SCHEMA that the serde tests expect the builder to pass to the query builder. */ private static final PhysicalSchema PHYSICAL_AGGREGATE_SCHEMA = PhysicalSchema.from( + AGGREGATE_SCHEMA, + SerdeOption.none() + ); + private static final FunctionCall AGG0 = new FunctionCall( + QualifiedName.of("AGG0"), + ImmutableList.of(new QualifiedNameReference(QualifiedName.of("ARGUMENT0"))) + ); + private static final FunctionCall AGG1 = new FunctionCall( + QualifiedName.of("AGG1"), + ImmutableList.of(new QualifiedNameReference(QualifiedName.of("ARGUMENT1"))) + ); + /** Aggregate function calls handed to the step under test. */ private static final List FUNCTIONS = ImmutableList.of(AGG0, AGG1); + /** Query context built by pushing "agg" then "regate"; the name test expects the store name "agg-regate". */ private static final QueryContext CTX = + new QueryContext.Stacker(new QueryId("qid")).push("agg").push("regate").getQueryContext(); + private static final KeyFormat KEY_FORMAT = KeyFormat.nonWindowed(FormatInfo.of(Format.KAFKA)); + private static final ValueFormat VALUE_FORMAT =
ValueFormat.of(FormatInfo.of(Format.JSON)); + + @Mock + private KGroupedTable groupedTable; + @Mock + private KTable aggregated; + @Mock + private KTable aggregatedWithResults; + @Mock + private KsqlQueryBuilder queryBuilder; + @Mock + private FunctionRegistry functionRegistry; + @Mock + private AggregateParams.Factory aggregateParamsFactory; + @Mock + private AggregateParams aggregateParams; + @Mock + private KudafInitializer initializer; + @Mock + private KudafAggregator aggregator; + @Mock + private ValueMapper resultMapper; + @Mock + private KudafUndoAggregator undoAggregator; + @Mock + private MaterializedFactory materializedFactory; + @Mock + private KeySerde keySerde; + @Mock + private Serde valueSerde; + @Mock + private Materialized> materialized; + @Mock + private ExecutionStep> sourceStep; + + /** Step under test; built in init(). */ private TableAggregate aggregate; + + /** Installs every stub the tests rely on (source schema, serdes, function registry, aggregate params, materialized factory, groupedTable.aggregate and mapValues) and builds the TableAggregate step. */ @Before + @SuppressWarnings("unchecked") + public void init() { + when(sourceStep.getSchema()).thenReturn(INPUT_SCHEMA); + when(queryBuilder.buildKeySerde(any(), any(), any())).thenReturn(keySerde); + when(queryBuilder.buildValueSerde(any(), any(), any())).thenReturn(valueSerde); + when(queryBuilder.getFunctionRegistry()).thenReturn(functionRegistry); + when(aggregateParamsFactory.create(any(), anyInt(), any(), any())).thenReturn(aggregateParams); + when(aggregateParams.getAggregator()).thenReturn(aggregator); + when(aggregateParams.getUndoAggregator()).thenReturn(undoAggregator); + when(aggregateParams.getInitializer()).thenReturn(initializer); + when(aggregator.getResultMapper()).thenReturn(resultMapper); + when(materializedFactory.>create(any(), any(), any())) + .thenReturn(materialized); + when(groupedTable.aggregate(any(), any(), any(), any(Materialized.class))).thenReturn( + aggregated); + when(aggregated.mapValues(any(ValueMapper.class))).thenReturn(aggregatedWithResults); + aggregate = new TableAggregate( + new DefaultExecutionStepProperties(INPUT_SCHEMA, CTX), + sourceStep, + Formats.of(KEY_FORMAT, VALUE_FORMAT,
SerdeOption.none()), + /* presumably the number of leading non-aggregate columns (REQUIRED0, REQUIRED1) - TODO confirm */ 2, + FUNCTIONS, + AGGREGATE_SCHEMA + ); + } + + /** Expects aggregate(initializer, aggregator, undoAggregator, materialized) then mapValues(resultMapper), in that order and with no other calls, and the mapped table as the result. */ @Test + public void shouldBuildAggregateCorrectly() { + // When: + final KTable result = TableAggregateBuilder.build( + groupedTable, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + assertThat(result, is(aggregatedWithResults)); + final InOrder inOrder = Mockito.inOrder(groupedTable, aggregated, aggregatedWithResults); + inOrder.verify(groupedTable).aggregate(initializer, aggregator, undoAggregator, materialized); + inOrder.verify(aggregated).mapValues(resultMapper); + inOrder.verifyNoMoreInteractions(); + } + + /** Expects the materialized to be created with the serdes returned by the query builder. */ @Test + public void shouldBuildMaterializedWithCorrectSerdesForAggregate() { + // When: + TableAggregateBuilder.build( + groupedTable, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(materializedFactory).create(same(keySerde), same(valueSerde), any()); + } + + /** Expects the materialized to be named "agg-regate". */ @Test + public void shouldBuildMaterializedWithCorrectNameForAggregate() { + // When: + TableAggregateBuilder.build( + groupedTable, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(materializedFactory).create(any(), any(), eq("agg-regate")); + } + + /** Expects the key serde to be built from the key format, PHYSICAL_AGGREGATE_SCHEMA and CTX. */ @Test + public void shouldBuildKeySerdeCorrectlyForAggregate() { + // When: + TableAggregateBuilder.build( + groupedTable, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(queryBuilder).buildKeySerde(KEY_FORMAT.getFormatInfo(), PHYSICAL_AGGREGATE_SCHEMA, CTX); + } + + /** Expects the value serde to be built from the value format, PHYSICAL_AGGREGATE_SCHEMA and CTX. */ @Test + public void shouldBuildValueSerdeCorrectlyForAggregate() { + // When: + TableAggregateBuilder.build( + groupedTable, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(queryBuilder).buildValueSerde( + VALUE_FORMAT.getFormatInfo(), + PHYSICAL_AGGREGATE_SCHEMA, + CTX + ); + } + + /** Expects the aggregate params to be created from INPUT_SCHEMA, 2, the function registry and FUNCTIONS. */ @Test + public void shouldBuildAggregatorParamsCorrectlyForAggregate() { +
// When: + TableAggregateBuilder.build( + groupedTable, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(aggregateParamsFactory).create(INPUT_SCHEMA, 2, functionRegistry, FUNCTIONS); + } +} \ No newline at end of file