Skip to content

Commit

Permalink
SD-9457: Update AggregateRel.CopyWithExpressionRewrite to rewrite mea…
Browse files Browse the repository at this point in the history
…sure functions in addition to filters
  • Loading branch information
EpsilonPrime committed Jan 29, 2025
1 parent d6e63d9 commit 7a051b7
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 6 deletions.
5 changes: 5 additions & 0 deletions expr/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,11 @@ func (a *AggregateFunction) IntermediateType() (types.FuncDefArgType, error) {
return a.declaration.Intermediate()
}

// SetArg sets the specified argument to the provided value. The index is not checked for validity.
func (a *AggregateFunction) SetArg(i int, arg types.FuncArg) {
a.args[i] = arg
}

func (a *AggregateFunction) String() string {
var b strings.Builder

Expand Down
1 change: 1 addition & 0 deletions expr/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func (e *ExtensionRegistry) LookupScalarFunction(anchor uint32) (*extensions.Sca
return e.Set.LookupScalarFunction(anchor, e.c)
}

// LookupAggregateFunction returns an AggregateFunctionVariant associated with a previously used function's anchor.
func (e *ExtensionRegistry) LookupAggregateFunction(anchor uint32) (*extensions.AggregateFunctionVariant, bool) {
return e.Set.LookupAggregateFunction(anchor, e.c)
}
Expand Down
30 changes: 28 additions & 2 deletions plan/relations.go
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,30 @@ func (ar *AggregateRel) Copy(newInputs ...Rel) (Rel, error) {
return &aggregate, nil
}

func (ar *AggregateRel) rewriteAggregateFunc(rewriteFunc RewriteFunc, f *expr.AggregateFunction) (*expr.AggregateFunction, error) {
if f == nil {
return f, nil
}
newF := f
argsAreEqual := true
for i := 0; i < f.NArgs(); i++ {
arg := f.Arg(i)
if exp, ok := arg.(expr.Expression); ok {
var newExp expr.Expression
var err error
if newExp, err = rewriteFunc(exp); err != nil {
return nil, err
}
newF.SetArg(i, newExp)
argsAreEqual = argsAreEqual && exp == newExp
}
}
if argsAreEqual {
return f, nil
}
return newF, nil
}

func (ar *AggregateRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs ...Rel) (Rel, error) {
if len(newInputs) != 1 {
return nil, substraitgo.ErrInvalidInputCount
Expand All @@ -1187,8 +1211,10 @@ func (ar *AggregateRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newIn
if newMeasures[i].filter, err = rewriteFunc(m.filter); err != nil {
return nil, err
}
measuresAreEqual = measuresAreEqual && newMeasures[i].filter == m.filter
newMeasures[i].measure = m.measure
if newMeasures[i].measure, err = ar.rewriteAggregateFunc(rewriteFunc, m.measure); err != nil {
return nil, err
}
measuresAreEqual = measuresAreEqual && newMeasures[i].filter == m.filter && newMeasures[i].measure == m.measure
}
if groupsAreEqual && measuresAreEqual && newInputs[0] == ar.input {
return ar, nil
Expand Down
23 changes: 19 additions & 4 deletions plan/relations_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package plan

import (
"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/v3/extensions"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -26,8 +28,19 @@ func createPrimitiveBool(value bool) expr.Expression {
}

func TestRelations_Copy(t *testing.T) {
aggregateRel := &AggregateRel{input: createVirtualTableReadRel(1), groupingExpressions: []expr.Expression{createPrimitiveFloat(1.0)}, groupingReferences: [][]uint32{{0}},
measures: []AggRelMeasure{{filter: expr.NewPrimitiveLiteral(false, false)}}}
extReg := expr.NewExtensionRegistry(extensions.NewSet(), &extensions.DefaultCollection)
aggregateFnID := extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
Name: "avg",
}
aggregateFn, err := expr.NewAggregateFunc(extReg,
aggregateFnID, nil, types.AggInvocationAll,
types.AggPhaseInitialToResult, nil, createPrimitiveFloat(1.0))
require.NoError(t, err)

aggregateRel := &AggregateRel{input: createVirtualTableReadRel(1),
groupingExpressions: []expr.Expression{createPrimitiveFloat(1.0)}, groupingReferences: [][]uint32{{0}},
measures: []AggRelMeasure{{measure: aggregateFn, filter: expr.NewPrimitiveLiteral(false, false)}}}
crossRel := &CrossRel{left: createVirtualTableReadRel(1), right: createVirtualTableReadRel(2)}
extensionLeafRel := &ExtensionLeafRel{}
extensionMultiRel := &ExtensionMultiRel{inputs: []Rel{createVirtualTableReadRel(1), createVirtualTableReadRel(2)}}
Expand Down Expand Up @@ -91,8 +104,10 @@ func TestRelations_Copy(t *testing.T) {
}
panic("unexpected expression type")
},
expectedRel: &AggregateRel{input: createVirtualTableReadRel(8), groupingExpressions: []expr.Expression{createPrimitiveFloat(9.0)}, groupingReferences: [][]uint32{{0}},
measures: []AggRelMeasure{{filter: expr.NewPrimitiveLiteral(true, false)}}},
expectedRel: &AggregateRel{input: createVirtualTableReadRel(8),
groupingExpressions: []expr.Expression{createPrimitiveFloat(9.0)},
groupingReferences: [][]uint32{{0}},
measures: []AggRelMeasure{{filter: expr.NewPrimitiveLiteral(true, false)}}},
},
{
name: "ExtensionLeafRel Copy with new inputs",
Expand Down

0 comments on commit 7a051b7

Please sign in to comment.