From 97de6533cdbaea90d3c430872388925360194713 Mon Sep 17 00:00:00 2001 From: Chandra Sanapala Date: Mon, 16 Dec 2024 11:48:03 +0530 Subject: [PATCH 1/2] feat: get scalar function invocation object from testcase --- testcases/parser/nodes.go | 75 +++++++++++++++++++++++++++++++++- testcases/parser/parse_test.go | 39 +++++++++++++----- 2 files changed, 103 insertions(+), 11 deletions(-) diff --git a/testcases/parser/nodes.go b/testcases/parser/nodes.go index e59c26d..3204a9c 100644 --- a/testcases/parser/nodes.go +++ b/testcases/parser/nodes.go @@ -3,8 +3,10 @@ package parser import ( "fmt" "strconv" + "strings" "github.com/substrait-io/substrait-go/expr" + "github.com/substrait-io/substrait-go/extensions" "github.com/substrait-io/substrait-go/types" ) @@ -43,6 +45,77 @@ type TestCase struct { FuncType TestFuncType } +func (tc *TestCase) GetFunctionOptions() []*types.FunctionOption { + if len(tc.Options) == 0 { + return nil + } + funcOptions := make([]*types.FunctionOption, 0) + for key, value := range tc.Options { + funcOptions = append(funcOptions, &types.FunctionOption{ + Name: key, + Preference: []string{value}, + }) + } + return funcOptions +} + +func (tc *TestCase) scalarSignatureKey() string { + var b strings.Builder + for i, a := range tc.Args { + if i != 0 { + b.WriteByte('_') + } + b.WriteString(a.Type.ShortString()) + } + return b.String() +} + +func (tc *TestCase) aggregateSignatureKey() string { + var b strings.Builder + for i, a := range tc.AggregateArgs { + if i != 0 { + b.WriteByte('_') + } + b.WriteString(a.ColumnType.ShortString()) + } + return b.String() +} + +func (tc *TestCase) signatureKey() string { + if tc.FuncType == ScalarFuncType { + return tc.scalarSignatureKey() + } + return tc.aggregateSignatureKey() +} + +func (tc *TestCase) CompoundFunctionName() string { + return tc.FuncName + ":" + tc.signatureKey() +} + +func (tc *TestCase) ID() extensions.ID { + baseURI := tc.BaseURI + if !strings.HasPrefix(baseURI, "https") || !strings.HasPrefix(baseURI, "http") { + baseURI = "https://github.com/substrait-io/substrait/blob/main" + tc.BaseURI + } + return extensions.ID{ + URI: baseURI, + Name: tc.CompoundFunctionName(), + } +} + +func (tc *TestCase) GetScalarFunctionInvocation(reg *expr.ExtensionRegistry) (*expr.ScalarFunction, error) { + if tc.FuncType != ScalarFuncType { + return nil, fmt.Errorf("not a scalar function testcase") + } + id := tc.ID() + args := make([]types.FuncArg, len(tc.Args)) + for i, arg := range tc.Args { + args[i] = arg.Value + } + + return expr.NewScalarFunc(*reg, id, tc.GetFunctionOptions(), args...) +} + type TestGroup struct { Description string TestCases []*TestCase @@ -69,7 +142,7 @@ func newAggregateArgument(tableName string, columnName string, columnType types. return nil, err } if index < 0 { - return nil, fmt.Errorf("Column index must be greater than or equal to 0") + return nil, fmt.Errorf("column index must be greater than or equal to 0") } return &AggregateArgument{ TableName: tableName, diff --git a/testcases/parser/parse_test.go b/testcases/parser/parse_test.go index a6db999..22d19fa 100644 --- a/testcases/parser/parse_test.go +++ b/testcases/parser/parse_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" "github.com/substrait-io/substrait" "github.com/substrait-io/substrait-go/expr" + "github.com/substrait-io/substrait-go/extensions" "github.com/substrait-io/substrait-go/literal" "github.com/substrait-io/substrait-go/types" ) @@ -33,6 +34,19 @@ add(120::i8, 10::i8) [overflow:ERROR] = testFile, err := ParseTestCasesFromString(header + tests) require.NoError(t, err) assert.Len(t, testFile.TestCases, 3) + + arithURI := "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + ids := []string{"add:i8_i8", "add:i16_i16", "add:i8_i8"} + reg := expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection) + for i, tc := range testFile.TestCases { + assert.Equal(t, extensions.ID{URI: arithURI, Name: ids[i]}, tc.ID()) + scalarFunc, err1 := tc.GetScalarFunctionInvocation(®) + require.NoError(t, err1) + assert.Equal(t, tc.FuncName, scalarFunc.Name()) + require.Equal(t, 2, scalarFunc.NArgs()) + assert.Equal(t, tc.Args[0].Value, scalarFunc.Arg(0)) + assert.Equal(t, tc.Args[1].Value, scalarFunc.Arg(1)) + } } func TestParseDataTimeExample(t *testing.T) { @@ -245,18 +259,19 @@ add(2::fp64, 2::fp64) [overflow:ERROR, rounding:TIE_TO_EVEN] = 4::fp64` } func TestParseAggregateFunc(t *testing.T) { - header := makeAggregateTestHeader("v1.0", "extensions/functions_arithmetic.yaml") + header := makeAggregateTestHeader("v1.0", "/extensions/functions_arithmetic.yaml") tests := `# basic avg((1,2,3)::fp32) = 2::fp64 sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = ` + arithUri := "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" testFile, err := ParseTestCasesFromString(header + tests) require.NoError(t, err) require.NotNil(t, testFile) assert.Len(t, testFile.TestCases, 2) assert.Equal(t, "avg", testFile.TestCases[0].FuncName) assert.Contains(t, testFile.TestCases[0].GroupDesc, "basic") - assert.Equal(t, testFile.TestCases[0].BaseURI, "extensions/functions_arithmetic.yaml") + assert.Equal(t, testFile.TestCases[0].BaseURI, "/extensions/functions_arithmetic.yaml") assert.Len(t, testFile.TestCases[0].Args, 0) assert.Len(t, testFile.TestCases[0].AggregateArgs, 1) assert.Equal(t, "fp32", testFile.TestCases[0].AggregateArgs[0].ColumnType.String()) @@ -269,16 +284,20 @@ sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = Date: Mon, 16 Dec 2024 23:24:07 +0530 Subject: [PATCH 2/2] address review comments --- testcases/parser/nodes.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/testcases/parser/nodes.go b/testcases/parser/nodes.go index 3204a9c..9a9ba0f 100644 --- a/testcases/parser/nodes.go +++ b/testcases/parser/nodes.go @@ -82,10 +82,14 @@ func (tc *TestCase) aggregateSignatureKey() string { } func (tc *TestCase) signatureKey() string { - if tc.FuncType == ScalarFuncType { + switch tc.FuncType { + case ScalarFuncType: return tc.scalarSignatureKey() + case AggregateFuncType: + return tc.aggregateSignatureKey() + default: + panic(fmt.Sprintf("unsupported function type: %s", tc.FuncType)) } - return tc.aggregateSignatureKey() } func (tc *TestCase) CompoundFunctionName() string { @@ -94,7 +98,7 @@ func (tc *TestCase) CompoundFunctionName() string { func (tc *TestCase) ID() extensions.ID { baseURI := tc.BaseURI - if !strings.HasPrefix(baseURI, "https") || !strings.HasPrefix(baseURI, "http") { + if strings.HasPrefix(baseURI, "/") { baseURI = "https://github.com/substrait-io/substrait/blob/main" + tc.BaseURI } return extensions.ID{