Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update AggregateRel.CopyWithExpressionRewrite to rewrite measure functions in addition to filters #112

Merged
merged 8 commits into from
Jan 29, 2025
22 changes: 22 additions & 0 deletions expr/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,28 @@
return a.declaration.Intermediate()
}

func (a *AggregateFunction) Clone() *AggregateFunction {
newA := *a
if a.args != nil {
newA.args = make([]types.FuncArg, len(a.args))
copy(newA.args, a.args)
}
if a.options != nil {
newA.options = make([]*types.FunctionOption, len(a.options))
copy(newA.options, a.options)
}
if a.Sorts != nil {
newA.Sorts = make([]SortField, len(a.Sorts))
copy(newA.Sorts, a.Sorts)
}
return &newA

Check warning on line 828 in expr/functions.go

View check run for this annotation

Codecov / codecov/patch

expr/functions.go#L814-L828

Added lines #L814 - L828 were not covered by tests
}

// SetArg sets the specified argument to the provided value. The index is not checked for validity.
func (a *AggregateFunction) SetArg(i int, arg types.FuncArg) {
a.args[i] = arg

Check warning on line 833 in expr/functions.go

View check run for this annotation

Codecov / codecov/patch

expr/functions.go#L832-L833

Added lines #L832 - L833 were not covered by tests
}

func (a *AggregateFunction) String() string {
var b strings.Builder

Expand Down
6 changes: 6 additions & 0 deletions expr/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ type ExtensionRegistry struct {
c *extensions.Collection
}

// NewExtensionRegistry creates a new registry. If you have an existing plan you can use GetExtensionSet() to
// populate an extensions.Set.
func NewExtensionRegistry(extSet extensions.Set, c *extensions.Collection) ExtensionRegistry {
if c == nil {
panic("cannot create registry with nil collection")
}
return ExtensionRegistry{Set: extSet, c: c}
}

// NewEmptyExtensionRegistry creates an empty registry useful starting from scratch.
func NewEmptyExtensionRegistry(c *extensions.Collection) ExtensionRegistry {
return NewExtensionRegistry(extensions.NewSet(), c)
}
Expand All @@ -28,14 +31,17 @@ func (e *ExtensionRegistry) LookupType(anchor uint32) (extensions.Type, bool) {
return e.Set.LookupType(anchor, e.c)
}

// LookupScalarFunction returns a ScalarFunctionVariant associated with a previously used function's anchor.
func (e *ExtensionRegistry) LookupScalarFunction(anchor uint32) (*extensions.ScalarFunctionVariant, bool) {
return e.Set.LookupScalarFunction(anchor, e.c)
}

// LookupAggregateFunction returns an AggregateFunctionVariant associated with a previously used function's anchor.
func (e *ExtensionRegistry) LookupAggregateFunction(anchor uint32) (*extensions.AggregateFunctionVariant, bool) {
return e.Set.LookupAggregateFunction(anchor, e.c)
}

// LookupWindowFunction returns a WindowFunctionVariant associated with a previously used function's anchor.
func (e *ExtensionRegistry) LookupWindowFunction(anchor uint32) (*extensions.WindowFunctionVariant, bool) {
return e.Set.LookupWindowFunction(anchor, e.c)
}
30 changes: 28 additions & 2 deletions plan/relations.go
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,30 @@
return &aggregate, nil
}

func (ar *AggregateRel) rewriteAggregateFunc(rewriteFunc RewriteFunc, f *expr.AggregateFunction) (*expr.AggregateFunction, error) {
if f == nil {
return f, nil
}

Check warning on line 1174 in plan/relations.go

View check run for this annotation

Codecov / codecov/patch

plan/relations.go#L1173-L1174

Added lines #L1173 - L1174 were not covered by tests
newF := f.Clone()
argsAreEqual := true
for i := 0; i < f.NArgs(); i++ {
arg := f.Arg(i)
if exp, ok := arg.(expr.Expression); ok {
var newExp expr.Expression
var err error
if newExp, err = rewriteFunc(exp); err != nil {
return nil, err
}

Check warning on line 1184 in plan/relations.go

View check run for this annotation

Codecov / codecov/patch

plan/relations.go#L1183-L1184

Added lines #L1183 - L1184 were not covered by tests
newF.SetArg(i, newExp)
argsAreEqual = argsAreEqual && exp == newExp
}
}
if argsAreEqual {
return f, nil
}
return newF, nil
}

func (ar *AggregateRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs ...Rel) (Rel, error) {
if len(newInputs) != 1 {
return nil, substraitgo.ErrInvalidInputCount
Expand All @@ -1187,8 +1211,10 @@
if newMeasures[i].filter, err = rewriteFunc(m.filter); err != nil {
return nil, err
}
measuresAreEqual = measuresAreEqual && newMeasures[i].filter == m.filter
newMeasures[i].measure = m.measure
if newMeasures[i].measure, err = ar.rewriteAggregateFunc(rewriteFunc, m.measure); err != nil {
return nil, err
}

Check warning on line 1216 in plan/relations.go

View check run for this annotation

Codecov / codecov/patch

plan/relations.go#L1215-L1216

Added lines #L1215 - L1216 were not covered by tests
measuresAreEqual = measuresAreEqual && newMeasures[i].filter == m.filter && newMeasures[i].measure == m.measure
}
if groupsAreEqual && measuresAreEqual && newInputs[0] == ar.input {
return ar, nil
Expand Down
52 changes: 39 additions & 13 deletions plan/relations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/v3/expr"
"github.com/substrait-io/substrait-go/v3/extensions"
"github.com/substrait-io/substrait-go/v3/proto"
"github.com/substrait-io/substrait-go/v3/types"
)
Expand All @@ -26,8 +28,24 @@ func createPrimitiveBool(value bool) expr.Expression {
}

func TestRelations_Copy(t *testing.T) {
aggregateRel := &AggregateRel{input: createVirtualTableReadRel(1), groupingExpressions: []expr.Expression{createPrimitiveFloat(1.0)}, groupingReferences: [][]uint32{{0}},
measures: []AggRelMeasure{{filter: expr.NewPrimitiveLiteral(false, false)}}}
extReg := expr.NewExtensionRegistry(extensions.NewSet(), &extensions.DefaultCollection)
aggregateFnID := extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
Name: "avg",
}
aggregateFn, err := expr.NewAggregateFunc(extReg,
aggregateFnID, nil, types.AggInvocationAll,
types.AggPhaseInitialToResult, nil, createPrimitiveFloat(1.0))
require.NoError(t, err)
aggregateFnRevised, err := expr.NewAggregateFunc(extReg,
aggregateFnID, nil, types.AggInvocationAll,
types.AggPhaseInitialToResult, nil, createPrimitiveFloat(9.0))
require.NoError(t, err)

aggregateRel := &AggregateRel{input: createVirtualTableReadRel(1),
groupingExpressions: []expr.Expression{createPrimitiveFloat(1.0)},
groupingReferences: [][]uint32{{0}},
measures: []AggRelMeasure{{measure: aggregateFn, filter: expr.NewPrimitiveLiteral(false, false)}}}
crossRel := &CrossRel{left: createVirtualTableReadRel(1), right: createVirtualTableReadRel(2)}
extensionLeafRel := &ExtensionLeafRel{}
extensionMultiRel := &ExtensionMultiRel{inputs: []Rel{createVirtualTableReadRel(1), createVirtualTableReadRel(2)}}
Expand Down Expand Up @@ -60,10 +78,13 @@ func TestRelations_Copy(t *testing.T) {
}
testCases := []relationTestCase{
{
name: "AggregateRel Copy with new inputs",
relation: aggregateRel,
newInputs: []Rel{createVirtualTableReadRel(6)},
expectedRel: &AggregateRel{input: createVirtualTableReadRel(6), groupingReferences: aggregateRel.groupingReferences, groupingExpressions: aggregateRel.groupingExpressions, measures: aggregateRel.measures},
name: "AggregateRel Copy with new inputs",
relation: aggregateRel,
newInputs: []Rel{createVirtualTableReadRel(6)},
expectedRel: &AggregateRel{input: createVirtualTableReadRel(6),
groupingReferences: aggregateRel.groupingReferences,
groupingExpressions: aggregateRel.groupingExpressions,
measures: aggregateRel.measures},
},
{
name: "AggregateRel Copy with same inputs and noOpRewrite",
Expand All @@ -73,13 +94,16 @@ func TestRelations_Copy(t *testing.T) {
expectedSameRel: true,
},
{
name: "AggregateRel Copy with new Inputs and noOpReWrite",
relation: aggregateRel,
newInputs: []Rel{createVirtualTableReadRel(7)},
expectedRel: &AggregateRel{input: createVirtualTableReadRel(7), groupingExpressions: aggregateRel.groupingExpressions, groupingReferences: aggregateRel.groupingReferences, measures: aggregateRel.measures},
name: "AggregateRel Copy with new Inputs and noOpReWrite",
relation: aggregateRel,
newInputs: []Rel{createVirtualTableReadRel(7)},
expectedRel: &AggregateRel{input: createVirtualTableReadRel(7),
groupingExpressions: aggregateRel.groupingExpressions,
groupingReferences: aggregateRel.groupingReferences,
measures: aggregateRel.measures},
},
{
name: "AggregateRel Copy with new Inputs and reWriteFunc",
name: "AggregateRel Copy with new Inputs and rewriteFunc",
relation: aggregateRel,
newInputs: []Rel{createVirtualTableReadRel(8)},
rewriteFunc: func(expression expr.Expression) (expr.Expression, error) {
Expand All @@ -91,8 +115,10 @@ func TestRelations_Copy(t *testing.T) {
}
panic("unexpected expression type")
},
expectedRel: &AggregateRel{input: createVirtualTableReadRel(8), groupingExpressions: []expr.Expression{createPrimitiveFloat(9.0)}, groupingReferences: [][]uint32{{0}},
measures: []AggRelMeasure{{filter: expr.NewPrimitiveLiteral(true, false)}}},
expectedRel: &AggregateRel{input: createVirtualTableReadRel(8),
groupingExpressions: []expr.Expression{createPrimitiveFloat(9.0)},
groupingReferences: [][]uint32{{0}},
measures: []AggRelMeasure{{measure: aggregateFnRevised, filter: expr.NewPrimitiveLiteral(true, false)}}},
},
{
name: "ExtensionLeafRel Copy with new inputs",
Expand Down
Loading