Skip to content

Commit

Permalink
feat: unify column data retrieval for aggregate test cases (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
srikrishnak authored Mar 2, 2025
1 parent b0cb727 commit 32ce783
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 0 deletions.
24 changes: 24 additions & 0 deletions testcases/parser/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,30 @@ func (tc *TestCase) GetAggregateFunctionInvocation(reg *expr.ExtensionRegistry,
return nil, fmt.Errorf("%w: no matching function found or %s", substraitgo.ErrNotFound, id)
}

func (tc *TestCase) GetAggregateColumnsData() ([][]expr.Literal, error) {
if tc.FuncType != AggregateFuncType {
return nil, fmt.Errorf("expected function type %v, but got %v", AggregateFuncType, tc.FuncType)
}

if len(tc.Columns) > 0 {
return tc.Columns, nil
}

columns := make([][]expr.Literal, len(tc.AggregateArgs))

for colIdx, arg := range tc.AggregateArgs {
values, ok := arg.Argument.Value.(*expr.NestedLiteral[expr.ListLiteralValue])
if !ok {
return nil, fmt.Errorf("column %d: expected NestedLiteral[ListLiteralValue], but got %T", colIdx, arg.Argument.Value)
}

columns[colIdx] = make([]expr.Literal, len(values.Value))
copy(columns[colIdx], values.Value)
}

return columns, nil
}

type TestGroup struct {
Description string
TestCases []*TestCase
Expand Down
90 changes: 90 additions & 0 deletions testcases/parser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,96 @@ func TestParseAggregateFuncWithVariousTypes(t *testing.T) {
}
}

func TestParseAggregateFuncAllFormats(t *testing.T) {
header := makeAggregateTestHeader("v1.0", "/extensions/functions_arithmetic.yaml")
header += "# basic\n"

tests := []struct {
testCaseStr string
wantData [][]expr.Literal
}{
{"avg((1,2,3)::i64) = 2::fp64", [][]expr.Literal{newInt64Values(1, 2, 3)}},
{"((1), (2), (3)) avg(col0::i64) = 2::fp64", [][]expr.Literal{newInt64Values(1, 2, 3)}},
{"DEFINE t1(i64) = ((1), (2), (3))\navg(t1.col0) = 2::fp64", [][]expr.Literal{newInt64Values(1, 2, 3)}},

// tests with empty input data
{"avg(()::i64) = 2::fp64", [][]expr.Literal{{}}},
{"DEFINE t1(i64) = ()\navg(t1.col0) = 2::fp64", [][]expr.Literal{{}}},

//tests with multiple columns
{"((20, 20), (-3, -3), (1, 1), (10,10), (5,5)) corr(col0::fp32, col1::fp32?) = 1::fp64?", [][]expr.Literal{newFloat32Values(false, 20, -3, 1, 10, 5), newFloat32Values(true, 20, -3, 1, 10, 5)}},
{"DEFINE t1(fp32, fp32?) = ((20, 20), (-3, -3), (1, 1), (10,10), (5,5))\ncorr(t1.col0, t1.col1) = 1::fp64?", [][]expr.Literal{newFloat32Values(false, 20, -3, 1, 10, 5), newFloat32Values(true, 20, -3, 1, 10, 5)}},
}
for _, test := range tests {
t.Run(test.testCaseStr, func(t *testing.T) {
testFile, err := ParseTestCasesFromString(header + test.testCaseStr)
require.NoError(t, err)
require.NotNil(t, testFile)
assert.Len(t, testFile.TestCases, 1)
tc := testFile.TestCases[0]
assert.Contains(t, test.testCaseStr, tc.FuncName)
assert.Equal(t, tc.GroupDesc, "basic")
assert.Equal(t, tc.BaseURI, "/extensions/functions_arithmetic.yaml")
assert.Len(t, tc.Args, 0)

// check that the types are correct
argTypes := tc.GetArgTypes()
assert.Len(t, argTypes, len(test.wantData))
if len(test.wantData[0]) > 0 {
for i, argType := range argTypes {
assert.Equal(t, argType, test.wantData[i][0].GetType())
}
} else {
// check that the type is correct for empty input data
assert.Equal(t, &types.Int64Type{Nullability: types.NullabilityRequired}, argTypes[0])
}

assert.Equal(t, AggregateFuncType, tc.FuncType)
_, err = tc.GetScalarFunctionInvocation(nil, nil)
require.Error(t, err)

reg := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())
testGetFunctionInvocation(t, tc, &reg, nil)
data, err := tc.GetAggregateColumnsData()
require.NoError(t, err)

// check that the data is correct
assert.Len(t, data, len(test.wantData))
assert.Equal(t, test.wantData, data)
})
}
}

func TestBadInputsToGetAggregateColumnsData(t *testing.T) {
tests := []struct {
name string
testCase *TestCase
expectedError error
}{
{
name: "invalid function type",
testCase: &TestCase{FuncType: ScalarFuncType},
expectedError: fmt.Errorf("expected function type %v, but got %v", AggregateFuncType, ScalarFuncType),
},
{
name: "invalid argument type",
testCase: &TestCase{
FuncType: AggregateFuncType,
AggregateArgs: []*AggregateArgument{{Argument: &CaseLiteral{Value: expr.NewNullLiteral(&types.Float32Type{})}}},
},
expectedError: fmt.Errorf("column 0: expected NestedLiteral[ListLiteralValue], but got %T", expr.NewNullLiteral(&types.Float32Type{})),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := tt.testCase.GetAggregateColumnsData()
assert.Error(t, err)
assert.Equal(t, tt.expectedError.Error(), err.Error())
})
}
}

func TestParseAggregateFuncWithMixedArgs(t *testing.T) {
header := makeAggregateTestHeader("v1.0", "/extensions/functions_arithmetic.yaml")
tests := `# basic
Expand Down

0 comments on commit 32ce783

Please sign in to comment.