From cb2e0b22fce18e630f099c4079ba20693bfdb94b Mon Sep 17 00:00:00 2001 From: Chandra Sanapala Date: Mon, 16 Dec 2024 05:30:39 +0530 Subject: [PATCH] feat: add func type in testcase --- testcases/parser/nodes.go | 17 ++++++++++++++++- testcases/parser/parse_test.go | 12 +++++++++--- testcases/parser/visitor.go | 35 +++++++++++++++++++++++++--------- 3 files changed, 51 insertions(+), 13 deletions(-) diff --git a/testcases/parser/nodes.go b/testcases/parser/nodes.go index 2bb5322..e59c26d 100644 --- a/testcases/parser/nodes.go +++ b/testcases/parser/nodes.go @@ -8,6 +8,14 @@ import ( "github.com/substrait-io/substrait-go/types" ) +type TestFuncType string + +const ( + ScalarFuncType TestFuncType = "scalar" + AggregateFuncType TestFuncType = "aggregate" + WindowFuncType TestFuncType = "window" +) + type CaseLiteral struct { Type types.Type ValueText string @@ -17,6 +25,7 @@ type CaseLiteral struct { type TestFileHeader struct { Version string + FuncType TestFuncType IncludedURI string } @@ -31,10 +40,16 @@ type TestCase struct { Columns [][]expr.Literal TableName string ColumnTypes []types.Type + FuncType TestFuncType +} + +type TestGroup struct { + Description string + TestCases []*TestCase } type TestFile struct { - Header TestFileHeader + Header *TestFileHeader TestCases []*TestCase } diff --git a/testcases/parser/parse_test.go b/testcases/parser/parse_test.go index ee25758..168d8d5 100644 --- a/testcases/parser/parse_test.go +++ b/testcases/parser/parse_test.go @@ -61,6 +61,7 @@ lt('2016-12-31T13:30:15'::ts, '2017-12-31T13:30:15'::ts) = true::bool timestampType := &types.TimestampType{Nullability: types.NullabilityUnspecified} assert.Equal(t, timestampType, testFile.TestCases[0].Args[0].Type) assert.Equal(t, timestampType, testFile.TestCases[0].Args[1].Type) + assert.Equal(t, ScalarFuncType, testFile.TestCases[0].FuncType) } func TestParseDecimalExample(t *testing.T) { @@ -240,12 +241,14 @@ sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = 0 { + expectedErrorMsg = fmt.Sprintf("Syntax error at line 5:%d: %s", test.position, test.errorMsg) + } assert.Contains(t, err.Error(), expectedErrorMsg) }) } @@ -425,6 +432,7 @@ corr(t1.col0, t2.col1) = 1::fp64`, }, {"((20, 20), (-3, -3), (1, 1), (10,10), (5,5)) corr(my_col::fp32, col0::fp32) = 1::fp64", "mismatched input 'my_col'"}, {"((20, 20), (-3, -3), (1, 1), (10,10), (5,5)) corr(col0::fp32, column1::fp32) = 1::fp64", "mismatched input 'column1'"}, + {"f8('13:01:01.234'::time) = 123::i32", "expected aggregate test case, got scalar"}, } for _, test := range tests { t.Run(test.testCaseStr, func(t *testing.T) { @@ -443,7 +451,6 @@ func TestParseAggregateTestWithVariousTypes(t *testing.T) { {"f1((1, 2, 3, 4)::i64) = 10::fp64"}, {"f1((1, 2, 3, 4)::i16) = 10.0::fp32"}, {"f1((1, 2, 3, 4)::i32) = 10::i64"}, - {"f2(1.0::fp32, 2.0::fp64) = -7.0::fp32"}, {"f3(('a', 'b')::string) = 'c'::str"}, {"f4((false, true)::boolean) = false::bool"}, {"f5((1.1, 2.2)::fp32) = 3.3::fp32"}, @@ -454,7 +461,6 @@ func TestParseAggregateTestWithVariousTypes(t *testing.T) { {"f6((1.1, 2.2, null)::dec?<38,10>) = 3.3::dec<38,10>"}, {"f8(('1991-01-01', '1991-02-02')::date) = '2001-01-01'::date"}, {"f8(('13:01:01.2345678', '14:01:01.333')::time) = 123456::i64"}, - {"f8('13:01:01.234'::time) = 123::i32"}, {"f8(('1991-01-01T01:02:03.456', '1991-01-01T00:00:00')::timestamp) = '1991-01-01T22:33:44'::ts"}, {"f8(('1991-01-01T01:02:03.456+05:30', '1991-01-01T00:00:00+15:30')::tstz) = 23::i32"}, {"f10(('P10Y5M', 'P11Y5M')::interval_year) = 'P21Y10M'::interval_year"}, diff --git a/testcases/parser/visitor.go b/testcases/parser/visitor.go index 204c7ff..2c918f0 100644 --- a/testcases/parser/visitor.go +++ b/testcases/parser/visitor.go @@ -17,6 +17,7 @@ type TestCaseVisitor struct { baseparser.FuncTestCaseParserVisitor ErrorListener util.VisitErrorListener literalTypeInContext types.Type + testFuncType TestFuncType } func (v *TestCaseVisitor) getLiteralTypeInContext() types.Type { @@ -41,7 +42,7 @@ func (v *TestCaseVisitor) Visit(tree antlr.ParseTree) interface{} { } func (v *TestCaseVisitor) VisitDoc(ctx *baseparser.DocContext) interface{} { - header := v.Visit(ctx.Header()).(TestFileHeader) + header := v.Visit(ctx.Header()).(*TestFileHeader) testcases := make([]*TestCase, 0, len(ctx.AllTestGroup())) for _, testGroup := range ctx.AllTestGroup() { groupTestCases := v.Visit(testGroup).([]*TestCase) @@ -57,9 +58,20 @@ func (v *TestCaseVisitor) VisitDoc(ctx *baseparser.DocContext) interface{} { } func (v *TestCaseVisitor) VisitHeader(ctx *baseparser.HeaderContext) interface{} { - return TestFileHeader{ - Version: ctx.Version().GetText(), - IncludedURI: v.Visit(ctx.Include()).(string), + header := v.Visit(ctx.Version()).(*TestFileHeader) + header.IncludedURI = v.Visit(ctx.Include()).(string) + return header +} + +func (v *TestCaseVisitor) VisitVersion(ctx *baseparser.VersionContext) interface{} { + testFuncType := ScalarFuncType + if ctx.SubstraitAggregateTest() != nil { + testFuncType = AggregateFuncType + } + v.testFuncType = testFuncType + return &TestFileHeader{ + Version: ctx.FormatVersion().GetText(), + FuncType: testFuncType, } } @@ -67,17 +79,17 @@ func (v *TestCaseVisitor) VisitInclude(ctx *baseparser.IncludeContext) interface return getRawStringFromStringLiteral(ctx.StringLiteral(0).GetText()) } -type TestGroup struct { - Description string - TestCases []*TestCase -} - func (v *TestCaseVisitor) VisitScalarFuncTestGroup(ctx *baseparser.ScalarFuncTestGroupContext) interface{} { groupDesc := v.Visit(ctx.TestGroupDescription()).(string) groupTestCases := make([]*TestCase, 0, len(ctx.AllTestCase())) + if v.testFuncType != ScalarFuncType { + v.ErrorListener.ReportVisitError(fmt.Errorf("expected %v test case, got scalar", v.testFuncType)) + return groupTestCases + } for _, tc := range ctx.AllTestCase() { testcase := v.Visit(tc).(*TestCase) testcase.GroupDesc = groupDesc + testcase.FuncType = ScalarFuncType groupTestCases = append(groupTestCases, testcase) } return groupTestCases @@ -86,9 +98,14 @@ func (v *TestCaseVisitor) VisitScalarFuncTestGroup(ctx *baseparser.ScalarFuncTes func (v *TestCaseVisitor) VisitAggregateFuncTestGroup(ctx *baseparser.AggregateFuncTestGroupContext) interface{} { groupDesc := v.Visit(ctx.TestGroupDescription()).(string) groupTestCases := make([]*TestCase, 0, len(ctx.AllAggFuncTestCase())) + if v.testFuncType != AggregateFuncType { + v.ErrorListener.ReportVisitError(fmt.Errorf("expected %v test case, got aggregate", v.testFuncType)) + return groupTestCases + } for _, tc := range ctx.AllAggFuncTestCase() { testcase := v.Visit(tc).(*TestCase) testcase.GroupDesc = groupDesc + testcase.FuncType = AggregateFuncType groupTestCases = append(groupTestCases, testcase) } return groupTestCases