diff --git a/core/src/main/java/io/substrait/expression/WindowBound.java b/core/src/main/java/io/substrait/expression/WindowBound.java index 2273eb8db..15724c7b3 100644 --- a/core/src/main/java/io/substrait/expression/WindowBound.java +++ b/core/src/main/java/io/substrait/expression/WindowBound.java @@ -18,6 +18,9 @@ enum Direction { FOLLOWING } + public static CurrentRowWindowBound CURRENT_ROW = + ImmutableWindowBound.CurrentRowWindowBound.builder().build(); + @Value.Immutable abstract static class UnboundedWindowBound implements WindowBound { @Override @@ -26,6 +29,10 @@ public BoundedKind boundedKind() { } public abstract Direction direction(); + + public static ImmutableWindowBound.UnboundedWindowBound.Builder builder() { + return ImmutableWindowBound.UnboundedWindowBound.builder(); + } } @Value.Immutable @@ -33,12 +40,16 @@ abstract static class BoundedWindowBound implements WindowBound { @Override public BoundedKind boundedKind() { - return BoundedKind.UNBOUNDED; + return BoundedKind.BOUNDED; } public abstract Direction direction(); public abstract Expression offset(); + + public static ImmutableWindowBound.BoundedWindowBound.Builder builder() { + return ImmutableWindowBound.BoundedWindowBound.builder(); + } } @Value.Immutable diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index 7d85c34ed..6ecbe3729 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -1,5 +1,6 @@ package io.substrait.expression.proto; +import io.substrait.expression.AbstractExpressionVisitor; import io.substrait.expression.ExpressionVisitor; import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; @@ -418,7 +419,8 @@ public Expression visit(io.substrait.expression.Expression.Window expr) throws R expr.partitionBy().stream() .map(e -> e.accept(this)) .collect(java.util.stream.Collectors.toList()); - var builder = Expression.WindowFunction.newBuilder(); + var outputType = expr.getType().accept(typeProtoConverter); + var builder = Expression.WindowFunction.newBuilder().setOutputType(outputType); if (expr.hasNormalAggregateFunction()) { var aggMeasureFunc = expr.aggregateFunction().getFunction(); var funcReference = extensionCollector.getFunctionReference(aggMeasureFunc.declaration()); @@ -427,18 +429,22 @@ public Expression visit(io.substrait.expression.Expression.Window expr) throws R aggMeasureFunc.arguments().stream() .map(a -> a.accept(aggMeasureFunc.declaration(), 0, argVisitor)) .collect(java.util.stream.Collectors.toList()); - var ordinal = aggMeasureFunc.aggregationPhase().ordinal(); - builder.setFunctionReference(funcReference).setPhaseValue(ordinal).addAllArguments(args); + builder + .setFunctionReference(funcReference) + .setPhase(aggMeasureFunc.aggregationPhase().toProto()) + .addAllArguments(args); } else { var windowFunc = expr.windowFunction().getFunction(); var funcReference = extensionCollector.getFunctionReference(windowFunc.declaration()); - var ordinal = windowFunc.aggregationPhase().ordinal(); var argVisitor = FunctionArg.toProto(typeProtoConverter, this); var args = windowFunc.arguments().stream() .map(a -> a.accept(windowFunc.declaration(), 0, argVisitor)) .collect(java.util.stream.Collectors.toList()); - builder.setFunctionReference(funcReference).setPhaseValue(ordinal).addAllArguments(args); + builder + .setFunctionReference(funcReference) + .setPhase(windowFunc.aggregationPhase().toProto()) + .addAllArguments(args); } var sortFields = expr.orderBy().stream() @@ -463,57 +469,68 @@ public Expression visit(io.substrait.expression.Expression.Window expr) throws R .build(); } + private static class LiteralToWindowBoundOffset + extends AbstractExpressionVisitor { + + @Override + public Long visitFallback(io.substrait.expression.Expression expr) { + throw new RuntimeException( + String.format("Expected positive integer for Window Bound offset, received: %s", expr)); + } + + private static long offsetIsPositive(long offset) { + if (offset >= 1) { + return offset; + } + throw new RuntimeException( + String.format("Expected positive offset for Window Bound offset, recieved: %d", offset)); + } + + @Override + public Long visit(io.substrait.expression.Expression.I8Literal expr) throws RuntimeException { + return offsetIsPositive(expr.value()); + } + + @Override + public Long visit(io.substrait.expression.Expression.I16Literal expr) throws RuntimeException { + return offsetIsPositive(expr.value()); + } + + @Override + public Long visit(io.substrait.expression.Expression.I32Literal expr) throws RuntimeException { + return offsetIsPositive(expr.value()); + } + + @Override + public Long visit(io.substrait.expression.Expression.I64Literal expr) throws RuntimeException { + return offsetIsPositive(expr.value()); + } + } + private Expression.WindowFunction.Bound toBound(io.substrait.expression.WindowBound windowBound) { var boundedKind = windowBound.boundedKind(); - Expression.WindowFunction.Bound bound = null; - switch (boundedKind) { - case CURRENT_ROW -> bound = - Expression.WindowFunction.Bound.newBuilder() - .setCurrentRow(Expression.WindowFunction.Bound.CurrentRow.getDefaultInstance()) - .build(); + return switch (boundedKind) { + case CURRENT_ROW -> Expression.WindowFunction.Bound.newBuilder() + .setCurrentRow(Expression.WindowFunction.Bound.CurrentRow.getDefaultInstance()) + .build(); case BOUNDED -> { WindowBound.BoundedWindowBound boundedWindowBound = (WindowBound.BoundedWindowBound) windowBound; - var offset = boundedWindowBound.offset(); - boolean isPreceding = boundedWindowBound.direction() == WindowBound.Direction.PRECEDING; - io.substrait.expression.Expression.I32Literal offsetLiteral = - (io.substrait.expression.Expression.I32Literal) offset; - var offsetVal = offsetLiteral.value(); - var boundedProto = Expression.WindowFunction.Bound.Unbounded.getDefaultInstance(); - if (isPreceding) { - var offsetProto = - Expression.WindowFunction.Bound.Preceding.newBuilder().setOffset(offsetVal).build(); - bound = Expression.WindowFunction.Bound.newBuilder().setPreceding(offsetProto).build(); - } else { - var offsetProto = - Expression.WindowFunction.Bound.Following.newBuilder().setOffset(offsetVal).build(); - bound = Expression.WindowFunction.Bound.newBuilder().setFollowing(offsetProto).build(); - } - } - case UNBOUNDED -> { - WindowBound.UnboundedWindowBound unboundedWindowBound = - (WindowBound.UnboundedWindowBound) windowBound; - boolean isPreceding = unboundedWindowBound.direction() == WindowBound.Direction.PRECEDING; - var unboundedProto = Expression.WindowFunction.Bound.Unbounded.getDefaultInstance(); - if (isPreceding) { - var preceding = Expression.WindowFunction.Bound.Preceding.newBuilder().build(); - bound = - Expression.WindowFunction.Bound.newBuilder() - .setUnbounded(unboundedProto) - .setPreceding(preceding) - .build(); - } else { - var following = Expression.WindowFunction.Bound.Following.newBuilder().build(); - bound = - Expression.WindowFunction.Bound.newBuilder() - .setUnbounded(unboundedProto) - .setFollowing(following) - .build(); - } + var offset = boundedWindowBound.offset().accept(new LiteralToWindowBoundOffset()); + yield switch (boundedWindowBound.direction()) { + case PRECEDING -> Expression.WindowFunction.Bound.newBuilder() + .setPreceding( + Expression.WindowFunction.Bound.Preceding.newBuilder().setOffset(offset)) + .build(); + case FOLLOWING -> Expression.WindowFunction.Bound.newBuilder() + .setFollowing( + Expression.WindowFunction.Bound.Following.newBuilder().setOffset(offset)) + .build(); + }; } - default -> throw new RuntimeException( - String.format("Unexpected Expression.WindowFunction.Bound enum:%s", boundedKind)); - } - return bound; + case UNBOUNDED -> Expression.WindowFunction.Bound.newBuilder() + .setUnbounded(Expression.WindowFunction.Bound.Unbounded.getDefaultInstance()) + .build(); + }; } } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index a49b7b74b..34ce6bb96 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -5,7 +5,10 @@ import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; import io.substrait.expression.ImmutableExpression; +import io.substrait.expression.WindowBound; +import io.substrait.expression.WindowFunctionInvocation; import io.substrait.extension.ExtensionLookup; +import io.substrait.extension.ImmutableSimpleExtension; import io.substrait.extension.SimpleExtension; import io.substrait.relation.ProtoRelConverter; import io.substrait.type.Type; @@ -110,7 +113,73 @@ public Expression from(io.substrait.proto.Expression expr) { yield ImmutableExpression.ScalarFunctionInvocation.builder() .addAllArguments(args) .declaration(declaration) - .outputType(protoTypeConverter.from(expr.getScalarFunction().getOutputType())) + .outputType(protoTypeConverter.from(scalarFunction.getOutputType())) + .build(); + } + case WINDOW_FUNCTION -> { + var windowFunction = expr.getWindowFunction(); + var functionReference = windowFunction.getFunctionReference(); + SimpleExtension.WindowFunctionVariant functionVariant; + try { + functionVariant = lookup.getWindowFunction(functionReference, extensions); + } catch (RuntimeException e) { + // TODO: Ideally we shouldn't need to catch a RuntimeException to be able to attempt our + // second lookup + var aggFunctionVariant = lookup.getAggregateFunction(functionReference, extensions); + functionVariant = + ImmutableSimpleExtension.WindowFunctionVariant.builder() + // Sets all fields declared in the Function interface + .from(aggFunctionVariant) + // Set WindowFunctionVariant fields + .decomposability(aggFunctionVariant.decomposability()) + .intermediate(aggFunctionVariant.intermediate()) + // Aggregate Functions used in Windows have WindowType Streaming + .windowType(SimpleExtension.WindowType.STREAMING) + .build(); + } + final SimpleExtension.WindowFunctionVariant declaration = functionVariant; + + var pF = new FunctionArg.ProtoFrom(this, protoTypeConverter); + var args = + IntStream.range(0, windowFunction.getArgumentsCount()) + .mapToObj(i -> pF.convert(declaration, i, windowFunction.getArguments(i))) + .collect(java.util.stream.Collectors.toList()); + var partitionExprs = + windowFunction.getPartitionsList().stream() + .map(this::from) + .collect(java.util.stream.Collectors.toList()); + var sortFields = + windowFunction.getSortsList().stream() + .map( + s -> + Expression.SortField.builder() + .direction(Expression.SortDirection.fromProto(s.getDirection())) + .expr(from(s.getExpr())) + .build()) + .collect(java.util.stream.Collectors.toList()); + var wfi = + WindowFunctionInvocation.builder() + .addAllArguments(args) + .declaration(declaration) + .outputType(protoTypeConverter.from(windowFunction.getOutputType())) + .aggregationPhase(Expression.AggregationPhase.fromProto(windowFunction.getPhase())) + .addAllSort(sortFields) + .invocation( + Expression.AggregationInvocation.fromProto(windowFunction.getInvocation())) + .build(); + + WindowBound lowerBound = toLowerBound(windowFunction.getLowerBound()); + WindowBound upperBound = toUpperBound(windowFunction.getUpperBound()); + + var wf = ImmutableExpression.WindowFunction.builder().function(wfi).build(); + yield Expression.Window.builder() + .windowFunction(wf) + .hasNormalAggregateFunction(false) + .type(protoTypeConverter.from(windowFunction.getOutputType())) + .partitionBy(partitionExprs) + .orderBy(sortFields) + .lowerBound(lowerBound) + .upperBound(upperBound) .build(); } case IF_THEN -> { @@ -200,13 +269,51 @@ public Expression from(io.substrait.proto.Expression expr) { } } - // TODO window, enum. - case WINDOW_FUNCTION, ENUM -> throw new UnsupportedOperationException( + // TODO enum. + case ENUM -> throw new UnsupportedOperationException( "Unsupported type: " + expr.getRexTypeCase()); default -> throw new IllegalArgumentException("Unknown type: " + expr.getRexTypeCase()); }; } + private WindowBound toLowerBound(io.substrait.proto.Expression.WindowFunction.Bound bound) { + return toWindowBound( + bound, + WindowBound.UnboundedWindowBound.builder() + .direction(WindowBound.Direction.PRECEDING) + .build()); + } + + private WindowBound toUpperBound(io.substrait.proto.Expression.WindowFunction.Bound bound) { + return toWindowBound( + bound, + WindowBound.UnboundedWindowBound.builder() + .direction(WindowBound.Direction.FOLLOWING) + .build()); + } + + private WindowBound toWindowBound( + io.substrait.proto.Expression.WindowFunction.Bound bound, WindowBound defaultBound) { + return switch (bound.getKindCase()) { + case PRECEDING -> WindowBound.BoundedWindowBound.builder() + .direction(WindowBound.Direction.PRECEDING) + .offset( + Expression.Literal.I64Literal.builder() + .value(bound.getPreceding().getOffset()) + .build()) + .build(); + case FOLLOWING -> WindowBound.BoundedWindowBound.builder() + .direction(WindowBound.Direction.FOLLOWING) + .offset( + Expression.Literal.I64Literal.builder() + .value(bound.getFollowing().getOffset()) + .build()) + .build(); + case CURRENT_ROW -> WindowBound.CURRENT_ROW; + case UNBOUNDED, KIND_NOT_SET -> defaultBound; + }; + } + public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { return switch (literal.getLiteralTypeCase()) { case BOOLEAN -> ExpressionCreator.bool(literal.getNullable(), literal.getBoolean()); diff --git a/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java b/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java index 5e2076209..51aa38ebd 100644 --- a/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java @@ -24,6 +24,17 @@ public SimpleExtension.ScalarFunctionVariant getScalarFunction( return extensions.getScalarFunction(anchor); } + public SimpleExtension.WindowFunctionVariant getWindowFunction( + int reference, SimpleExtension.ExtensionCollection extensions) { + var anchor = functionAnchorMap.get(reference); + if (anchor == null) { + throw new IllegalArgumentException( + "Unknown function id. Make sure that the function id provided was shared in the extensions section of the plan."); + } + + return extensions.getWindowFunction(anchor); + } + public SimpleExtension.AggregateFunctionVariant getAggregateFunction( int reference, SimpleExtension.ExtensionCollection extensions) { var anchor = functionAnchorMap.get(reference); diff --git a/core/src/main/java/io/substrait/extension/ExtensionLookup.java b/core/src/main/java/io/substrait/extension/ExtensionLookup.java index c7f03c5ee..ada3af0dd 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ExtensionLookup.java @@ -9,6 +9,9 @@ public interface ExtensionLookup { SimpleExtension.ScalarFunctionVariant getScalarFunction( int reference, SimpleExtension.ExtensionCollection extensions); + SimpleExtension.WindowFunctionVariant getWindowFunction( + int reference, SimpleExtension.ExtensionCollection extensions); + SimpleExtension.AggregateFunctionVariant getAggregateFunction( int reference, SimpleExtension.ExtensionCollection extensions); diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index 657ee8644..fbc7b8e96 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -61,7 +61,7 @@ enum Decomposability { MANY } - enum WindowType { + public enum WindowType { PARTITION, STREAMING } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java index 389e4bc53..1d7cabc0c 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java @@ -72,6 +72,7 @@ public class FunctionMappings { s(SqlStdOperatorTable.MIN, "min"), s(SqlStdOperatorTable.MAX, "max"), s(SqlStdOperatorTable.SUM, "sum"), + s(SqlStdOperatorTable.SUM0, "sum0"), s(SqlStdOperatorTable.COUNT, "count"), s(SqlStdOperatorTable.APPROX_COUNT_DISTINCT, "approx_count_distinct"), s(SqlStdOperatorTable.AVG, "avg")) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java index 4540d8bb9..5b7f08ad1 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java @@ -175,7 +175,7 @@ private Optional findWindowFunctionInvocation( private WindowBound toWindowBound( RexWindowBound rexWindowBound, RexExpressionConverter rexExpressionConverter) { if (rexWindowBound.isCurrentRow()) { - return ImmutableWindowBound.CurrentRowWindowBound.builder().build(); + return WindowBound.CURRENT_ROW; } if (rexWindowBound.isUnbounded()) { var direction = findWindowBoundDirection(rexWindowBound); diff --git a/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java new file mode 100644 index 000000000..486678155 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java @@ -0,0 +1,143 @@ +package io.substrait.isthmus; + +import java.io.IOException; +import org.apache.calcite.sql.parser.SqlParseException; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +public class WindowFunctionTest extends PlanTestBase { + + @Nested + class WindowFunctionInvocations { + + @Test + void rowNumber() throws IOException, SqlParseException { + assertProtoPlanRoundrip("select O_ORDERKEY, row_number() over () from ORDERS"); + } + + @ParameterizedTest + @ValueSource(strings = {"rank", "dense_rank", "percent_rank"}) + void rankFunctions(String rankFunction) throws IOException, SqlParseException { + var query = + String.format( + "select O_ORDERKEY, %s() over (order by O_SHIPPRIORITY) from ORDERS", rankFunction); + assertProtoPlanRoundrip(query); + } + + @ParameterizedTest + @ValueSource(strings = {"rank", "dense_rank", "percent_rank"}) + void rankFunctionsWithPartitions(String rankFunction) throws IOException, SqlParseException { + var query = + String.format( + "select O_ORDERKEY, %s() over (partition by O_CUSTKEY order by O_SHIPPRIORITY) from ORDERS", + rankFunction); + assertProtoPlanRoundrip(query); + } + + @Test + void cumeDist() throws IOException, SqlParseException { + assertProtoPlanRoundrip( + "select O_ORDERKEY, cume_dist() over (order by O_SHIPPRIORITY) from ORDERS"); + } + + @Test + @Disabled + void ntile() throws IOException, SqlParseException { + // TODO: The WindowFunctionConverter has some assumptions about function arguments that need + // to be addressed for this to work. + assertProtoPlanRoundrip("select O_ORDERKEY, ntile(4) over () from ORDERS"); + } + } + + @Nested + class BoundRoundTripping { + // Calcite is clever and will elide bounds if they are not needed. The following test queries + // are such that bounds will be included to better verify round-tripping. + // + // Plan summaries are included to show that bounds are included. They were generated using the + // static RelOptUtil.toString(RelNode rel) method with a debugger. + + @Test + void unbounded() throws IOException, SqlParseException { + /* + LogicalProject(EXPR$0=[MAX($7) OVER ()]) + LogicalTableScan(table=[[ORDERS]]) + */ + assertProtoPlanRoundrip("select max(O_SHIPPRIORITY) over () from ORDERS"); + } + + @Test + void unboundedPreceding() throws IOException, SqlParseException { + /* + LogicalProject(EXPR$0=[MIN($7) OVER (PARTITION BY $1 ORDER BY $4 ROWS UNBOUNDED PRECEDING)]) + LogicalTableScan(table=[[ORDERS]]) + */ + var overClause = "partition by O_CUSTKEY order by O_ORDERDATE rows unbounded preceding"; + assertProtoPlanRoundrip( + String.format("select min(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); + } + + @Test + void unboundedFollowing() throws IOException, SqlParseException { + /* + LogicalProject(EXPR$0=[MAX($7) OVER (PARTITION BY $1 ORDER BY $4 ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)]) + LogicalTableScan(table=[[ORDERS]]) + */ + var overClaus = + "partition by O_CUSTKEY order by O_ORDERDATE rows between current row AND unbounded following"; + assertProtoPlanRoundrip( + String.format("select max(O_SHIPPRIORITY) over (%s) from ORDERS", overClaus)); + } + + @Test + void rowsPrecedingToCurrent() throws IOException, SqlParseException { + /* + LogicalProject(EXPR$0=[MIN($7) OVER (PARTITION BY $1 ORDER BY $4 ROWS 1 PRECEDING)]) + LogicalTableScan(table=[[ORDERS]]) + */ + var overClause = + "partition by O_CUSTKEY order by O_ORDERDATE rows between 1 preceding and current row"; + assertProtoPlanRoundrip( + String.format("select min(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); + } + + @Test + void currentToRowsFollowing() throws IOException, SqlParseException { + /* + LogicalProject(EXPR$0=[MAX($7) OVER (PARTITION BY $1 ORDER BY $4 ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING)]) + LogicalTableScan(table=[[ORDERS]]) + */ + var overClause = + "partition by O_CUSTKEY order by O_ORDERDATE rows between current row and 2 following"; + assertProtoPlanRoundrip( + String.format("select max(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); + } + + @Test + void rowsPrecedingAndFollowing() throws IOException, SqlParseException { + /* + LogicalProject(EXPR$0=[MIN($7) OVER (PARTITION BY $1 ORDER BY $4 ROWS BETWEEN 3 PRECEDING AND 4 FOLLOWING)]) + LogicalTableScan(table=[[ORDERS]]) + */ + var overClause = + "partition by O_CUSTKEY order by O_ORDERDATE rows between 3 preceding and 4 following"; + assertProtoPlanRoundrip( + String.format("select min(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); + } + } + + @Nested + class AggregateFunctionInvocations { + + @ParameterizedTest + @ValueSource(strings = {"avg", "count", "max", "min", "sum"}) + void standardAggregateFunctions(String aggFunction) throws SqlParseException, IOException { + assertProtoPlanRoundrip( + String.format( + "select %s(L_LINENUMBER) over (partition BY L_PARTKEY) from lineitem", aggFunction)); + } + } +}