Skip to content

Commit

Permalink
feat: support reading Substrait plans with Window Functions (substrai…
Browse files Browse the repository at this point in the history
…t-io#165)

* fix: incorrect BoundedKind for BoundedWindowBound
* feat: add mapping for sum0 aggregate function
* feat(pojos): convenience methods for working with Window Bounds
  • Loading branch information
vbarua authored Aug 17, 2023
1 parent e211122 commit 50b77e8
Show file tree
Hide file tree
Showing 9 changed files with 350 additions and 57 deletions.
13 changes: 12 additions & 1 deletion core/src/main/java/io/substrait/expression/WindowBound.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ enum Direction {
FOLLOWING
}

public static CurrentRowWindowBound CURRENT_ROW =
ImmutableWindowBound.CurrentRowWindowBound.builder().build();

@Value.Immutable
abstract static class UnboundedWindowBound implements WindowBound {
@Override
Expand All @@ -26,19 +29,27 @@ public BoundedKind boundedKind() {
}

public abstract Direction direction();

public static ImmutableWindowBound.UnboundedWindowBound.Builder builder() {
return ImmutableWindowBound.UnboundedWindowBound.builder();
}
}

@Value.Immutable
abstract static class BoundedWindowBound implements WindowBound {

@Override
public BoundedKind boundedKind() {
return BoundedKind.UNBOUNDED;
return BoundedKind.BOUNDED;
}

public abstract Direction direction();

public abstract Expression offset();

public static ImmutableWindowBound.BoundedWindowBound.Builder builder() {
return ImmutableWindowBound.BoundedWindowBound.builder();
}
}

@Value.Immutable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.expression.proto;

import io.substrait.expression.AbstractExpressionVisitor;
import io.substrait.expression.ExpressionVisitor;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
Expand Down Expand Up @@ -418,7 +419,8 @@ public Expression visit(io.substrait.expression.Expression.Window expr) throws R
expr.partitionBy().stream()
.map(e -> e.accept(this))
.collect(java.util.stream.Collectors.toList());
var builder = Expression.WindowFunction.newBuilder();
var outputType = expr.getType().accept(typeProtoConverter);
var builder = Expression.WindowFunction.newBuilder().setOutputType(outputType);
if (expr.hasNormalAggregateFunction()) {
var aggMeasureFunc = expr.aggregateFunction().getFunction();
var funcReference = extensionCollector.getFunctionReference(aggMeasureFunc.declaration());
Expand All @@ -427,18 +429,22 @@ public Expression visit(io.substrait.expression.Expression.Window expr) throws R
aggMeasureFunc.arguments().stream()
.map(a -> a.accept(aggMeasureFunc.declaration(), 0, argVisitor))
.collect(java.util.stream.Collectors.toList());
var ordinal = aggMeasureFunc.aggregationPhase().ordinal();
builder.setFunctionReference(funcReference).setPhaseValue(ordinal).addAllArguments(args);
builder
.setFunctionReference(funcReference)
.setPhase(aggMeasureFunc.aggregationPhase().toProto())
.addAllArguments(args);
} else {
var windowFunc = expr.windowFunction().getFunction();
var funcReference = extensionCollector.getFunctionReference(windowFunc.declaration());
var ordinal = windowFunc.aggregationPhase().ordinal();
var argVisitor = FunctionArg.toProto(typeProtoConverter, this);
var args =
windowFunc.arguments().stream()
.map(a -> a.accept(windowFunc.declaration(), 0, argVisitor))
.collect(java.util.stream.Collectors.toList());
builder.setFunctionReference(funcReference).setPhaseValue(ordinal).addAllArguments(args);
builder
.setFunctionReference(funcReference)
.setPhase(windowFunc.aggregationPhase().toProto())
.addAllArguments(args);
}
var sortFields =
expr.orderBy().stream()
Expand All @@ -463,57 +469,68 @@ public Expression visit(io.substrait.expression.Expression.Window expr) throws R
.build();
}

private static class LiteralToWindowBoundOffset
extends AbstractExpressionVisitor<Long, RuntimeException> {

@Override
public Long visitFallback(io.substrait.expression.Expression expr) {
throw new RuntimeException(
String.format("Expected positive integer for Window Bound offset, received: %s", expr));
}

private static long offsetIsPositive(long offset) {
if (offset >= 1) {
return offset;
}
throw new RuntimeException(
String.format("Expected positive offset for Window Bound offset, recieved: %d", offset));
}

@Override
public Long visit(io.substrait.expression.Expression.I8Literal expr) throws RuntimeException {
return offsetIsPositive(expr.value());
}

@Override
public Long visit(io.substrait.expression.Expression.I16Literal expr) throws RuntimeException {
return offsetIsPositive(expr.value());
}

@Override
public Long visit(io.substrait.expression.Expression.I32Literal expr) throws RuntimeException {
return offsetIsPositive(expr.value());
}

@Override
public Long visit(io.substrait.expression.Expression.I64Literal expr) throws RuntimeException {
return offsetIsPositive(expr.value());
}
}

private Expression.WindowFunction.Bound toBound(io.substrait.expression.WindowBound windowBound) {
var boundedKind = windowBound.boundedKind();
Expression.WindowFunction.Bound bound = null;
switch (boundedKind) {
case CURRENT_ROW -> bound =
Expression.WindowFunction.Bound.newBuilder()
.setCurrentRow(Expression.WindowFunction.Bound.CurrentRow.getDefaultInstance())
.build();
return switch (boundedKind) {
case CURRENT_ROW -> Expression.WindowFunction.Bound.newBuilder()
.setCurrentRow(Expression.WindowFunction.Bound.CurrentRow.getDefaultInstance())
.build();
case BOUNDED -> {
WindowBound.BoundedWindowBound boundedWindowBound =
(WindowBound.BoundedWindowBound) windowBound;
var offset = boundedWindowBound.offset();
boolean isPreceding = boundedWindowBound.direction() == WindowBound.Direction.PRECEDING;
io.substrait.expression.Expression.I32Literal offsetLiteral =
(io.substrait.expression.Expression.I32Literal) offset;
var offsetVal = offsetLiteral.value();
var boundedProto = Expression.WindowFunction.Bound.Unbounded.getDefaultInstance();
if (isPreceding) {
var offsetProto =
Expression.WindowFunction.Bound.Preceding.newBuilder().setOffset(offsetVal).build();
bound = Expression.WindowFunction.Bound.newBuilder().setPreceding(offsetProto).build();
} else {
var offsetProto =
Expression.WindowFunction.Bound.Following.newBuilder().setOffset(offsetVal).build();
bound = Expression.WindowFunction.Bound.newBuilder().setFollowing(offsetProto).build();
}
}
case UNBOUNDED -> {
WindowBound.UnboundedWindowBound unboundedWindowBound =
(WindowBound.UnboundedWindowBound) windowBound;
boolean isPreceding = unboundedWindowBound.direction() == WindowBound.Direction.PRECEDING;
var unboundedProto = Expression.WindowFunction.Bound.Unbounded.getDefaultInstance();
if (isPreceding) {
var preceding = Expression.WindowFunction.Bound.Preceding.newBuilder().build();
bound =
Expression.WindowFunction.Bound.newBuilder()
.setUnbounded(unboundedProto)
.setPreceding(preceding)
.build();
} else {
var following = Expression.WindowFunction.Bound.Following.newBuilder().build();
bound =
Expression.WindowFunction.Bound.newBuilder()
.setUnbounded(unboundedProto)
.setFollowing(following)
.build();
}
var offset = boundedWindowBound.offset().accept(new LiteralToWindowBoundOffset());
yield switch (boundedWindowBound.direction()) {
case PRECEDING -> Expression.WindowFunction.Bound.newBuilder()
.setPreceding(
Expression.WindowFunction.Bound.Preceding.newBuilder().setOffset(offset))
.build();
case FOLLOWING -> Expression.WindowFunction.Bound.newBuilder()
.setFollowing(
Expression.WindowFunction.Bound.Following.newBuilder().setOffset(offset))
.build();
};
}
default -> throw new RuntimeException(
String.format("Unexpected Expression.WindowFunction.Bound enum:%s", boundedKind));
}
return bound;
case UNBOUNDED -> Expression.WindowFunction.Bound.newBuilder()
.setUnbounded(Expression.WindowFunction.Bound.Unbounded.getDefaultInstance())
.build();
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.ImmutableExpression;
import io.substrait.expression.WindowBound;
import io.substrait.expression.WindowFunctionInvocation;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.ImmutableSimpleExtension;
import io.substrait.extension.SimpleExtension;
import io.substrait.relation.ProtoRelConverter;
import io.substrait.type.Type;
Expand Down Expand Up @@ -110,7 +113,73 @@ public Expression from(io.substrait.proto.Expression expr) {
yield ImmutableExpression.ScalarFunctionInvocation.builder()
.addAllArguments(args)
.declaration(declaration)
.outputType(protoTypeConverter.from(expr.getScalarFunction().getOutputType()))
.outputType(protoTypeConverter.from(scalarFunction.getOutputType()))
.build();
}
case WINDOW_FUNCTION -> {
var windowFunction = expr.getWindowFunction();
var functionReference = windowFunction.getFunctionReference();
SimpleExtension.WindowFunctionVariant functionVariant;
try {
functionVariant = lookup.getWindowFunction(functionReference, extensions);
} catch (RuntimeException e) {
// TODO: Ideally we shouldn't need to catch a RuntimeException to be able to attempt our
// second lookup
var aggFunctionVariant = lookup.getAggregateFunction(functionReference, extensions);
functionVariant =
ImmutableSimpleExtension.WindowFunctionVariant.builder()
// Sets all fields declared in the Function interface
.from(aggFunctionVariant)
// Set WindowFunctionVariant fields
.decomposability(aggFunctionVariant.decomposability())
.intermediate(aggFunctionVariant.intermediate())
// Aggregate Functions used in Windows have WindowType Streaming
.windowType(SimpleExtension.WindowType.STREAMING)
.build();
}
final SimpleExtension.WindowFunctionVariant declaration = functionVariant;

var pF = new FunctionArg.ProtoFrom(this, protoTypeConverter);
var args =
IntStream.range(0, windowFunction.getArgumentsCount())
.mapToObj(i -> pF.convert(declaration, i, windowFunction.getArguments(i)))
.collect(java.util.stream.Collectors.toList());
var partitionExprs =
windowFunction.getPartitionsList().stream()
.map(this::from)
.collect(java.util.stream.Collectors.toList());
var sortFields =
windowFunction.getSortsList().stream()
.map(
s ->
Expression.SortField.builder()
.direction(Expression.SortDirection.fromProto(s.getDirection()))
.expr(from(s.getExpr()))
.build())
.collect(java.util.stream.Collectors.toList());
var wfi =
WindowFunctionInvocation.builder()
.addAllArguments(args)
.declaration(declaration)
.outputType(protoTypeConverter.from(windowFunction.getOutputType()))
.aggregationPhase(Expression.AggregationPhase.fromProto(windowFunction.getPhase()))
.addAllSort(sortFields)
.invocation(
Expression.AggregationInvocation.fromProto(windowFunction.getInvocation()))
.build();

WindowBound lowerBound = toLowerBound(windowFunction.getLowerBound());
WindowBound upperBound = toUpperBound(windowFunction.getUpperBound());

var wf = ImmutableExpression.WindowFunction.builder().function(wfi).build();
yield Expression.Window.builder()
.windowFunction(wf)
.hasNormalAggregateFunction(false)
.type(protoTypeConverter.from(windowFunction.getOutputType()))
.partitionBy(partitionExprs)
.orderBy(sortFields)
.lowerBound(lowerBound)
.upperBound(upperBound)
.build();
}
case IF_THEN -> {
Expand Down Expand Up @@ -200,13 +269,51 @@ public Expression from(io.substrait.proto.Expression expr) {
}
}

// TODO window, enum.
case WINDOW_FUNCTION, ENUM -> throw new UnsupportedOperationException(
// TODO enum.
case ENUM -> throw new UnsupportedOperationException(
"Unsupported type: " + expr.getRexTypeCase());
default -> throw new IllegalArgumentException("Unknown type: " + expr.getRexTypeCase());
};
}

private WindowBound toLowerBound(io.substrait.proto.Expression.WindowFunction.Bound bound) {
return toWindowBound(
bound,
WindowBound.UnboundedWindowBound.builder()
.direction(WindowBound.Direction.PRECEDING)
.build());
}

private WindowBound toUpperBound(io.substrait.proto.Expression.WindowFunction.Bound bound) {
return toWindowBound(
bound,
WindowBound.UnboundedWindowBound.builder()
.direction(WindowBound.Direction.FOLLOWING)
.build());
}

private WindowBound toWindowBound(
io.substrait.proto.Expression.WindowFunction.Bound bound, WindowBound defaultBound) {
return switch (bound.getKindCase()) {
case PRECEDING -> WindowBound.BoundedWindowBound.builder()
.direction(WindowBound.Direction.PRECEDING)
.offset(
Expression.Literal.I64Literal.builder()
.value(bound.getPreceding().getOffset())
.build())
.build();
case FOLLOWING -> WindowBound.BoundedWindowBound.builder()
.direction(WindowBound.Direction.FOLLOWING)
.offset(
Expression.Literal.I64Literal.builder()
.value(bound.getFollowing().getOffset())
.build())
.build();
case CURRENT_ROW -> WindowBound.CURRENT_ROW;
case UNBOUNDED, KIND_NOT_SET -> defaultBound;
};
}

public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
return switch (literal.getLiteralTypeCase()) {
case BOOLEAN -> ExpressionCreator.bool(literal.getNullable(), literal.getBoolean());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ public SimpleExtension.ScalarFunctionVariant getScalarFunction(
return extensions.getScalarFunction(anchor);
}

public SimpleExtension.WindowFunctionVariant getWindowFunction(
int reference, SimpleExtension.ExtensionCollection extensions) {
var anchor = functionAnchorMap.get(reference);
if (anchor == null) {
throw new IllegalArgumentException(
"Unknown function id. Make sure that the function id provided was shared in the extensions section of the plan.");
}

return extensions.getWindowFunction(anchor);
}

public SimpleExtension.AggregateFunctionVariant getAggregateFunction(
int reference, SimpleExtension.ExtensionCollection extensions) {
var anchor = functionAnchorMap.get(reference);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ public interface ExtensionLookup {
SimpleExtension.ScalarFunctionVariant getScalarFunction(
int reference, SimpleExtension.ExtensionCollection extensions);

SimpleExtension.WindowFunctionVariant getWindowFunction(
int reference, SimpleExtension.ExtensionCollection extensions);

SimpleExtension.AggregateFunctionVariant getAggregateFunction(
int reference, SimpleExtension.ExtensionCollection extensions);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ enum Decomposability {
MANY
}

enum WindowType {
public enum WindowType {
PARTITION,
STREAMING
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public class FunctionMappings {
s(SqlStdOperatorTable.MIN, "min"),
s(SqlStdOperatorTable.MAX, "max"),
s(SqlStdOperatorTable.SUM, "sum"),
s(SqlStdOperatorTable.SUM0, "sum0"),
s(SqlStdOperatorTable.COUNT, "count"),
s(SqlStdOperatorTable.APPROX_COUNT_DISTINCT, "approx_count_distinct"),
s(SqlStdOperatorTable.AVG, "avg"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ private Optional<WindowFunctionInvocation> findWindowFunctionInvocation(
private WindowBound toWindowBound(
RexWindowBound rexWindowBound, RexExpressionConverter rexExpressionConverter) {
if (rexWindowBound.isCurrentRow()) {
return ImmutableWindowBound.CurrentRowWindowBound.builder().build();
return WindowBound.CURRENT_ROW;
}
if (rexWindowBound.isUnbounded()) {
var direction = findWindowBoundDirection(rexWindowBound);
Expand Down
Loading

0 comments on commit 50b77e8

Please sign in to comment.