Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: get scalar function invocation object from testcase #91

Merged
merged 2 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion testcases/parser/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package parser
import (
"fmt"
"strconv"
"strings"

"github.com/substrait-io/substrait-go/expr"
"github.com/substrait-io/substrait-go/extensions"
"github.com/substrait-io/substrait-go/types"
)

Expand Down Expand Up @@ -43,6 +45,81 @@ type TestCase struct {
FuncType TestFuncType
}

func (tc *TestCase) GetFunctionOptions() []*types.FunctionOption {
if len(tc.Options) == 0 {
return nil
}
funcOptions := make([]*types.FunctionOption, 0)
for key, value := range tc.Options {
funcOptions = append(funcOptions, &types.FunctionOption{
Name: key,
Preference: []string{value},
})
}
return funcOptions
}

func (tc *TestCase) scalarSignatureKey() string {
var b strings.Builder
for i, a := range tc.Args {
if i != 0 {
b.WriteByte('_')
}
b.WriteString(a.Type.ShortString())
}
return b.String()
}

func (tc *TestCase) aggregateSignatureKey() string {
var b strings.Builder
for i, a := range tc.AggregateArgs {
if i != 0 {
b.WriteByte('_')
}
b.WriteString(a.ColumnType.ShortString())
}
return b.String()
}

func (tc *TestCase) signatureKey() string {
switch tc.FuncType {
case ScalarFuncType:
return tc.scalarSignatureKey()
case AggregateFuncType:
return tc.aggregateSignatureKey()
default:
panic(fmt.Sprintf("unsupported function type: %s", tc.FuncType))
}
}

func (tc *TestCase) CompoundFunctionName() string {
return tc.FuncName + ":" + tc.signatureKey()
}

func (tc *TestCase) ID() extensions.ID {
baseURI := tc.BaseURI
if strings.HasPrefix(baseURI, "/") {
baseURI = "https://github.com/substrait-io/substrait/blob/main" + tc.BaseURI
}
return extensions.ID{
URI: baseURI,
Name: tc.CompoundFunctionName(),
}
}

func (tc *TestCase) GetScalarFunctionInvocation(reg *expr.ExtensionRegistry) (*expr.ScalarFunction, error) {
if tc.FuncType != ScalarFuncType {
return nil, fmt.Errorf("not a scalar function testcase")
}
id := tc.ID()
args := make([]types.FuncArg, len(tc.Args))
for i, arg := range tc.Args {
args[i] = arg.Value
}

return expr.NewScalarFunc(*reg, id, tc.GetFunctionOptions(), args...)
}

type TestGroup struct {
Description string
TestCases []*TestCase
Expand All @@ -69,7 +146,7 @@ func newAggregateArgument(tableName string, columnName string, columnType types.
return nil, err
}
if index < 0 {
return nil, fmt.Errorf("Column index must be greater than or equal to 0")
return nil, fmt.Errorf("column index must be greater than or equal to 0")
}
return &AggregateArgument{
TableName: tableName,
Expand Down
39 changes: 29 additions & 10 deletions testcases/parser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait"
"github.com/substrait-io/substrait-go/expr"
"github.com/substrait-io/substrait-go/extensions"
"github.com/substrait-io/substrait-go/literal"
"github.com/substrait-io/substrait-go/types"
)
Expand All @@ -33,6 +34,19 @@ add(120::i8, 10::i8) [overflow:ERROR] = <!ERROR>
testFile, err := ParseTestCasesFromString(header + tests)
require.NoError(t, err)
assert.Len(t, testFile.TestCases, 3)

arithURI := "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
ids := []string{"add:i8_i8", "add:i16_i16", "add:i8_i8"}
reg := expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection)
for i, tc := range testFile.TestCases {
assert.Equal(t, extensions.ID{URI: arithURI, Name: ids[i]}, tc.ID())
scalarFunc, err1 := tc.GetScalarFunctionInvocation(&reg)
require.NoError(t, err1)
assert.Equal(t, tc.FuncName, scalarFunc.Name())
require.Equal(t, 2, scalarFunc.NArgs())
assert.Equal(t, tc.Args[0].Value, scalarFunc.Arg(0))
assert.Equal(t, tc.Args[1].Value, scalarFunc.Arg(1))
}
}

func TestParseDataTimeExample(t *testing.T) {
Expand Down Expand Up @@ -245,18 +259,19 @@ add(2::fp64, 2::fp64) [overflow:ERROR, rounding:TIE_TO_EVEN] = 4::fp64`
}

func TestParseAggregateFunc(t *testing.T) {
header := makeAggregateTestHeader("v1.0", "extensions/functions_arithmetic.yaml")
header := makeAggregateTestHeader("v1.0", "/extensions/functions_arithmetic.yaml")
tests := `# basic
avg((1,2,3)::fp32) = 2::fp64
sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = <!ERROR>`

arithUri := "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
testFile, err := ParseTestCasesFromString(header + tests)
require.NoError(t, err)
require.NotNil(t, testFile)
assert.Len(t, testFile.TestCases, 2)
assert.Equal(t, "avg", testFile.TestCases[0].FuncName)
assert.Contains(t, testFile.TestCases[0].GroupDesc, "basic")
assert.Equal(t, testFile.TestCases[0].BaseURI, "extensions/functions_arithmetic.yaml")
assert.Equal(t, testFile.TestCases[0].BaseURI, "/extensions/functions_arithmetic.yaml")
assert.Len(t, testFile.TestCases[0].Args, 0)
assert.Len(t, testFile.TestCases[0].AggregateArgs, 1)
assert.Equal(t, "fp32", testFile.TestCases[0].AggregateArgs[0].ColumnType.String())
Expand All @@ -269,16 +284,20 @@ sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = <!ER
assert.Equal(t, "fp64", testFile.TestCases[0].Result.Type.String())
assert.Equal(t, literal.NewFloat64(2), testFile.TestCases[0].Result.Value)
assert.Equal(t, AggregateFuncType, testFile.TestCases[0].FuncType)
assert.Equal(t, extensions.ID{URI: arithUri, Name: "avg:fp32"}, testFile.TestCases[0].ID())
_, err = testFile.TestCases[0].GetScalarFunctionInvocation(nil)
require.Error(t, err)

assert.Equal(t, "sum", testFile.TestCases[1].FuncName)
assert.Contains(t, testFile.TestCases[1].GroupDesc, "basic")
assert.Equal(t, testFile.TestCases[1].BaseURI, "extensions/functions_arithmetic.yaml")
assert.Equal(t, testFile.TestCases[1].BaseURI, "/extensions/functions_arithmetic.yaml")
assert.Len(t, testFile.TestCases[1].Args, 0)
assert.Len(t, testFile.TestCases[1].AggregateArgs, 1)
assert.Equal(t, AggregateFuncType, testFile.TestCases[1].FuncType)
assert.Equal(t, "i64", testFile.TestCases[1].AggregateArgs[0].ColumnType.String())
assert.Equal(t, newInt64List(9223372036854775806, 1, 1, 1, 1, 10000000000), testFile.TestCases[1].AggregateArgs[0].Argument.Value)
assert.Equal(t, "ERROR", testFile.TestCases[1].Options["overflow"])
assert.Equal(t, extensions.ID{URI: arithUri, Name: "sum:i64"}, testFile.TestCases[1].ID())
}

func newInt64List(values ...int64) interface{} {
Expand Down Expand Up @@ -308,7 +327,7 @@ func newFloat32Values(values ...float32) []expr.Literal {
}

func TestParseAggregateFuncCompact(t *testing.T) {
header := makeAggregateTestHeader("v1.0", "extensions/functions_arithmetic.yaml")
header := makeAggregateTestHeader("v1.0", "/extensions/functions_arithmetic.yaml")
tests := `# basic
((20, 20), (-3, -3), (1, 1), (10,10), (5,5)) corr(col0::fp32, col1::fp32) = 1::fp64
`
Expand All @@ -319,7 +338,7 @@ func TestParseAggregateFuncCompact(t *testing.T) {
assert.Len(t, testFile.TestCases, 1)
assert.Equal(t, "corr", testFile.TestCases[0].FuncName)
assert.Contains(t, testFile.TestCases[0].GroupDesc, "basic")
assert.Equal(t, testFile.TestCases[0].BaseURI, "extensions/functions_arithmetic.yaml")
assert.Equal(t, testFile.TestCases[0].BaseURI, "/extensions/functions_arithmetic.yaml")
assert.Len(t, testFile.TestCases[0].Args, 0)
assert.Len(t, testFile.TestCases[0].AggregateArgs, 2)
assert.Equal(t, newFloat32Values(20, -3, 1, 10, 5), testFile.TestCases[0].Columns[0])
Expand All @@ -341,7 +360,7 @@ func createAggregateArg(t *testing.T, tableName, columnName string, columnType t
}

func TestParseAggregateFuncWithMultipleArgs(t *testing.T) {
header := makeAggregateTestHeader("v1.0", "extensions/functions_arithmetic.yaml")
header := makeAggregateTestHeader("v1.0", "/extensions/functions_arithmetic.yaml")
tests := `# basic
DEFINE t1(fp32, fp32) = ((20, 20), (-3, -3), (1, 1), (10,10), (5,5.5))
corr(t1.col0, t1.col1) = 1::fp64
Expand All @@ -355,7 +374,7 @@ corr(t1.col1, t1.col0) = 1::fp64
assert.Len(t, testFile.TestCases, 2)
assert.Equal(t, "corr", testFile.TestCases[0].FuncName)
assert.Contains(t, testFile.TestCases[0].GroupDesc, "basic")
assert.Equal(t, testFile.TestCases[0].BaseURI, "extensions/functions_arithmetic.yaml")
assert.Equal(t, testFile.TestCases[0].BaseURI, "/extensions/functions_arithmetic.yaml")
assert.Len(t, testFile.TestCases[0].Args, 0)
assert.Len(t, testFile.TestCases[0].AggregateArgs, 2)
assert.Equal(t, newFloat32Values(20, -3, 1, 10, 5), testFile.TestCases[0].Columns[0])
Expand All @@ -365,7 +384,7 @@ corr(t1.col1, t1.col0) = 1::fp64

assert.Equal(t, "corr", testFile.TestCases[1].FuncName)
assert.Contains(t, testFile.TestCases[1].GroupDesc, "basic")
assert.Equal(t, testFile.TestCases[1].BaseURI, "extensions/functions_arithmetic.yaml")
assert.Equal(t, testFile.TestCases[1].BaseURI, "/extensions/functions_arithmetic.yaml")
assert.Len(t, testFile.TestCases[1].Args, 0)
assert.Len(t, testFile.TestCases[1].AggregateArgs, 2)
assert.Equal(t, newInt64Values(20, -3, 1, 10, 5), testFile.TestCases[1].Columns[0])
Expand All @@ -375,7 +394,7 @@ corr(t1.col1, t1.col0) = 1::fp64
}

func TestParseAggregateFuncWithVariousTypes(t *testing.T) {
header := makeAggregateTestHeader("v1.0", "extensions/functions_arithmetic.yaml")
header := makeAggregateTestHeader("v1.0", "/extensions/functions_arithmetic.yaml")
header += "# basic\n"

tests := []struct {
Expand All @@ -394,7 +413,7 @@ func TestParseAggregateFuncWithVariousTypes(t *testing.T) {
}

func TestParseAggregateFuncWithMixedArgs(t *testing.T) {
header := makeAggregateTestHeader("v1.0", "extensions/functions_arithmetic.yaml")
header := makeAggregateTestHeader("v1.0", "/extensions/functions_arithmetic.yaml")
tests := `# basic
((20), (-3), (1), (10)) LIST_AGG(col0::fp32, ','::string) = 1::fp64
DEFINE t1(fp32) = ((20), (-3), (1), (10))
Expand Down
Loading