Skip to content

Commit

Permalink
planner/core: migrate test-infra to testify for rule_ tests (pingcap#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tisonkun authored Feb 9, 2022
1 parent 01fdb60 commit 297455d
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 194 deletions.
7 changes: 6 additions & 1 deletion planner/core/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
"go.uber.org/goleak"
)

var testDataMap = make(testdata.BookKeeper, 4)
var testDataMap = make(testdata.BookKeeper, 5)
var indexMergeSuiteData testdata.TestData

func TestMain(m *testing.M) {
Expand All @@ -36,6 +36,7 @@ func TestMain(m *testing.M) {
testDataMap.LoadTestSuiteData("testdata", "index_merge_suite")
testDataMap.LoadTestSuiteData("testdata", "plan_normalized_suite")
testDataMap.LoadTestSuiteData("testdata", "stats_suite")
testDataMap.LoadTestSuiteData("testdata", "ordered_result_mode_suite")

indexMergeSuiteData = testDataMap["index_merge_suite"]

Expand Down Expand Up @@ -63,3 +64,7 @@ func GetPlanNormalizedSuiteData() testdata.TestData {
func GetStatsSuiteData() testdata.TestData {
return testDataMap["stats_suite"]
}

func GetOrderedResultModeSuiteData() testdata.TestData {
return testDataMap["ordered_result_mode_suite"]
}
17 changes: 7 additions & 10 deletions planner/core/rule_inject_extra_projection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,18 @@
package core

import (
. "github.com/pingcap/check"
"testing"

"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/mock"
"github.com/stretchr/testify/require"
)

var _ = Suite(&testInjectProjSuite{})

type testInjectProjSuite struct {
}

func (s *testInjectProjSuite) TestWrapCastForAggFuncs(c *C) {
func TestWrapCastForAggFuncs(t *testing.T) {
aggNames := []string{ast.AggFuncSum}
modes := []aggregation.AggFunctionMode{aggregation.CompleteMode,
aggregation.FinalMode, aggregation.Partial1Mode, aggregation.Partial1Mode}
Expand All @@ -45,7 +42,7 @@ func (s *testInjectProjSuite) TestWrapCastForAggFuncs(c *C) {
aggFunc, err := aggregation.NewAggFuncDesc(sctx, name,
[]expression.Expression{&expression.Constant{Value: types.Datum{}, RetType: types.NewFieldType(retType)}},
hasDistinct)
c.Assert(err, IsNil)
require.NoError(t, err)
aggFunc.Mode = mode
aggFuncs = append(aggFuncs, aggFunc)
}
Expand All @@ -61,9 +58,9 @@ func (s *testInjectProjSuite) TestWrapCastForAggFuncs(c *C) {
wrapCastForAggFuncs(mock.NewContext(), aggFuncs)
for i := range aggFuncs {
if aggFuncs[i].Mode != aggregation.FinalMode && aggFuncs[i].Mode != aggregation.Partial2Mode {
c.Assert(aggFuncs[i].RetTp.Tp, Equals, aggFuncs[i].Args[0].GetType().Tp)
require.Equal(t, aggFuncs[i].Args[0].GetType().Tp, aggFuncs[i].RetTp.Tp)
} else {
c.Assert(aggFuncs[i].Args[0].GetType().Tp, Equals, orgAggFuncs[i].Args[0].GetType().Tp)
require.Equal(t, orgAggFuncs[i].Args[0].GetType().Tp, aggFuncs[i].Args[0].GetType().Tp)
}
}
}
200 changes: 99 additions & 101 deletions planner/core/rule_join_reorder_dp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,18 @@ package core

import (
"fmt"
"testing"

. "github.com/pingcap/check"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/planner/property"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/stretchr/testify/require"
)

var _ = Suite(&testJoinReorderDPSuite{})

type testJoinReorderDPSuite struct {
ctx sessionctx.Context
statsMap map[int]*property.StatsInfo
}

func (s *testJoinReorderDPSuite) SetUpTest(c *C) {
s.ctx = MockContext()
s.ctx.GetSessionVars().PlanID = -1
}

type mockLogicalJoin struct {
logicalSchemaProducer
involvedNodeSet int
Expand All @@ -57,31 +46,27 @@ func (mj *mockLogicalJoin) recursiveDeriveStats(_ [][]*expression.Column) (*prop
return mj.statsMap[mj.involvedNodeSet], nil
}

func (s *testJoinReorderDPSuite) newMockJoin(lChild, rChild LogicalPlan, eqConds []*expression.ScalarFunction, _ []expression.Expression) LogicalPlan {
retJoin := mockLogicalJoin{}.init(s.ctx)
retJoin.schema = expression.MergeSchema(lChild.Schema(), rChild.Schema())
retJoin.statsMap = s.statsMap
if mj, ok := lChild.(*mockLogicalJoin); ok {
retJoin.involvedNodeSet = mj.involvedNodeSet
} else {
retJoin.involvedNodeSet = 1 << uint(lChild.ID())
}
if mj, ok := rChild.(*mockLogicalJoin); ok {
retJoin.involvedNodeSet |= mj.involvedNodeSet
} else {
retJoin.involvedNodeSet |= 1 << uint(rChild.ID())
}
retJoin.SetChildren(lChild, rChild)
return retJoin
}

func (s *testJoinReorderDPSuite) mockStatsInfo(state int, count float64) {
s.statsMap[state] = &property.StatsInfo{
RowCount: count,
func newMockJoin(ctx sessionctx.Context, statsMap map[int]*property.StatsInfo) func(lChild, rChild LogicalPlan, _ []*expression.ScalarFunction, _ []expression.Expression) LogicalPlan {
return func(lChild, rChild LogicalPlan, _ []*expression.ScalarFunction, _ []expression.Expression) LogicalPlan {
retJoin := mockLogicalJoin{}.init(ctx)
retJoin.schema = expression.MergeSchema(lChild.Schema(), rChild.Schema())
retJoin.statsMap = statsMap
if mj, ok := lChild.(*mockLogicalJoin); ok {
retJoin.involvedNodeSet = mj.involvedNodeSet
} else {
retJoin.involvedNodeSet = 1 << uint(lChild.ID())
}
if mj, ok := rChild.(*mockLogicalJoin); ok {
retJoin.involvedNodeSet |= mj.involvedNodeSet
} else {
retJoin.involvedNodeSet |= 1 << uint(rChild.ID())
}
retJoin.SetChildren(lChild, rChild)
return retJoin
}
}

func (s *testJoinReorderDPSuite) makeStatsMapForTPCHQ5() {
func makeStatsMapForTPCHQ5() map[int]*property.StatsInfo {
// Labeled as lineitem -> 0, orders -> 1, customer -> 2, supplier 3, nation 4, region 5
// This graph can be shown as following:
// +---------------+ +---------------+
Expand Down Expand Up @@ -112,48 +97,48 @@ func (s *testJoinReorderDPSuite) makeStatsMapForTPCHQ5() {
// | region |
// | |
// +---------------+
s.statsMap = make(map[int]*property.StatsInfo)
s.mockStatsInfo(3, 9103367)
s.mockStatsInfo(6, 2275919)
s.mockStatsInfo(7, 9103367)
s.mockStatsInfo(9, 59986052)
s.mockStatsInfo(11, 9103367)
s.mockStatsInfo(12, 5999974575)
s.mockStatsInfo(13, 59999974575)
s.mockStatsInfo(14, 9103543072)
s.mockStatsInfo(15, 99103543072)
s.mockStatsInfo(20, 1500000)
s.mockStatsInfo(22, 2275919)
s.mockStatsInfo(23, 7982159)
s.mockStatsInfo(24, 100000)
s.mockStatsInfo(25, 59986052)
s.mockStatsInfo(27, 9103367)
s.mockStatsInfo(28, 5999974575)
s.mockStatsInfo(29, 59999974575)
s.mockStatsInfo(30, 59999974575)
s.mockStatsInfo(31, 59999974575)
s.mockStatsInfo(48, 5)
s.mockStatsInfo(52, 299838)
s.mockStatsInfo(54, 454183)
s.mockStatsInfo(55, 1815222)
s.mockStatsInfo(56, 20042)
s.mockStatsInfo(57, 12022687)
s.mockStatsInfo(59, 1823514)
s.mockStatsInfo(60, 1201884359)
s.mockStatsInfo(61, 12001884359)
s.mockStatsInfo(62, 12001884359)
s.mockStatsInfo(63, 72985)

statsMap := make(map[int]*property.StatsInfo)
statsMap[3] = &property.StatsInfo{RowCount: 9103367}
statsMap[6] = &property.StatsInfo{RowCount: 2275919}
statsMap[7] = &property.StatsInfo{RowCount: 9103367}
statsMap[9] = &property.StatsInfo{RowCount: 59986052}
statsMap[11] = &property.StatsInfo{RowCount: 9103367}
statsMap[12] = &property.StatsInfo{RowCount: 5999974575}
statsMap[13] = &property.StatsInfo{RowCount: 59999974575}
statsMap[14] = &property.StatsInfo{RowCount: 9103543072}
statsMap[15] = &property.StatsInfo{RowCount: 99103543072}
statsMap[20] = &property.StatsInfo{RowCount: 1500000}
statsMap[22] = &property.StatsInfo{RowCount: 2275919}
statsMap[23] = &property.StatsInfo{RowCount: 7982159}
statsMap[24] = &property.StatsInfo{RowCount: 100000}
statsMap[25] = &property.StatsInfo{RowCount: 59986052}
statsMap[27] = &property.StatsInfo{RowCount: 9103367}
statsMap[28] = &property.StatsInfo{RowCount: 5999974575}
statsMap[29] = &property.StatsInfo{RowCount: 59999974575}
statsMap[30] = &property.StatsInfo{RowCount: 59999974575}
statsMap[31] = &property.StatsInfo{RowCount: 59999974575}
statsMap[48] = &property.StatsInfo{RowCount: 5}
statsMap[52] = &property.StatsInfo{RowCount: 299838}
statsMap[54] = &property.StatsInfo{RowCount: 454183}
statsMap[55] = &property.StatsInfo{RowCount: 1815222}
statsMap[56] = &property.StatsInfo{RowCount: 20042}
statsMap[57] = &property.StatsInfo{RowCount: 12022687}
statsMap[59] = &property.StatsInfo{RowCount: 1823514}
statsMap[60] = &property.StatsInfo{RowCount: 1201884359}
statsMap[61] = &property.StatsInfo{RowCount: 12001884359}
statsMap[62] = &property.StatsInfo{RowCount: 12001884359}
statsMap[63] = &property.StatsInfo{RowCount: 72985}
return statsMap
}

func (s *testJoinReorderDPSuite) newDataSource(name string, count int) LogicalPlan {
ds := DataSource{}.Init(s.ctx, 0)
func newDataSource(ctx sessionctx.Context, name string, count int) LogicalPlan {
ds := DataSource{}.Init(ctx, 0)
tan := model.NewCIStr(name)
ds.TableAsName = &tan
ds.schema = expression.NewSchema()
s.ctx.GetSessionVars().PlanColumnID++
ctx.GetSessionVars().PlanColumnID++
ds.schema.Append(&expression.Column{
UniqueID: s.ctx.GetSessionVars().PlanColumnID,
UniqueID: ctx.GetSessionVars().PlanColumnID,
RetType: types.NewFieldType(mysql.TypeLonglong),
})
ds.stats = &property.StatsInfo{
Expand All @@ -162,57 +147,70 @@ func (s *testJoinReorderDPSuite) newDataSource(name string, count int) LogicalPl
return ds
}

func (s *testJoinReorderDPSuite) planToString(plan LogicalPlan) string {
func planToString(plan LogicalPlan) string {
switch x := plan.(type) {
case *mockLogicalJoin:
return fmt.Sprintf("MockJoin{%v, %v}", s.planToString(x.children[0]), s.planToString(x.children[1]))
return fmt.Sprintf("MockJoin{%v, %v}", planToString(x.children[0]), planToString(x.children[1]))
case *DataSource:
return x.TableAsName.L
}
return ""
}

func (s *testJoinReorderDPSuite) TestDPReorderTPCHQ5(c *C) {
s.makeStatsMapForTPCHQ5()
func TestDPReorderTPCHQ5(t *testing.T) {
statsMap := makeStatsMapForTPCHQ5()

ctx := MockContext()
ctx.GetSessionVars().PlanID = -1
joinGroups := make([]LogicalPlan, 0, 6)
joinGroups = append(joinGroups, s.newDataSource("lineitem", 59986052))
joinGroups = append(joinGroups, s.newDataSource("orders", 15000000))
joinGroups = append(joinGroups, s.newDataSource("customer", 1500000))
joinGroups = append(joinGroups, s.newDataSource("supplier", 100000))
joinGroups = append(joinGroups, s.newDataSource("nation", 25))
joinGroups = append(joinGroups, s.newDataSource("region", 5))
joinGroups = append(joinGroups, newDataSource(ctx, "lineitem", 59986052))
joinGroups = append(joinGroups, newDataSource(ctx, "orders", 15000000))
joinGroups = append(joinGroups, newDataSource(ctx, "customer", 1500000))
joinGroups = append(joinGroups, newDataSource(ctx, "supplier", 100000))
joinGroups = append(joinGroups, newDataSource(ctx, "nation", 25))
joinGroups = append(joinGroups, newDataSource(ctx, "region", 5))

var eqConds []expression.Expression
eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[0].Schema().Columns[0], joinGroups[1].Schema().Columns[0]))
eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[1].Schema().Columns[0], joinGroups[2].Schema().Columns[0]))
eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[2].Schema().Columns[0], joinGroups[3].Schema().Columns[0]))
eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[0].Schema().Columns[0], joinGroups[3].Schema().Columns[0]))
eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[2].Schema().Columns[0], joinGroups[4].Schema().Columns[0]))
eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[3].Schema().Columns[0], joinGroups[4].Schema().Columns[0]))
eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[4].Schema().Columns[0], joinGroups[5].Schema().Columns[0]))
eqConds = append(eqConds, expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[0].Schema().Columns[0], joinGroups[1].Schema().Columns[0]))
eqConds = append(eqConds, expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[1].Schema().Columns[0], joinGroups[2].Schema().Columns[0]))
eqConds = append(eqConds, expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[2].Schema().Columns[0], joinGroups[3].Schema().Columns[0]))
eqConds = append(eqConds, expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[0].Schema().Columns[0], joinGroups[3].Schema().Columns[0]))
eqConds = append(eqConds, expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[2].Schema().Columns[0], joinGroups[4].Schema().Columns[0]))
eqConds = append(eqConds, expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[3].Schema().Columns[0], joinGroups[4].Schema().Columns[0]))
eqConds = append(eqConds, expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[4].Schema().Columns[0], joinGroups[5].Schema().Columns[0]))
solver := &joinReorderDPSolver{
baseSingleGroupJoinOrderSolver: &baseSingleGroupJoinOrderSolver{
ctx: s.ctx,
ctx: ctx,
},
newJoin: s.newMockJoin,
newJoin: newMockJoin(ctx, statsMap),
}
result, err := solver.solve(joinGroups, eqConds, nil)
c.Assert(err, IsNil)
c.Assert(s.planToString(result), Equals, "MockJoin{supplier, MockJoin{lineitem, MockJoin{orders, MockJoin{customer, MockJoin{nation, region}}}}}")
require.NoError(t, err)

expected := "MockJoin{supplier, MockJoin{lineitem, MockJoin{orders, MockJoin{customer, MockJoin{nation, region}}}}}"
require.Equal(t, expected, planToString(result))
}

func (s *testJoinReorderDPSuite) TestDPReorderAllCartesian(c *C) {
func TestDPReorderAllCartesian(t *testing.T) {
statsMap := makeStatsMapForTPCHQ5()

ctx := MockContext()
ctx.GetSessionVars().PlanID = -1

joinGroup := make([]LogicalPlan, 0, 4)
joinGroup = append(joinGroup, s.newDataSource("a", 100))
joinGroup = append(joinGroup, s.newDataSource("b", 100))
joinGroup = append(joinGroup, s.newDataSource("c", 100))
joinGroup = append(joinGroup, s.newDataSource("d", 100))
joinGroup = append(joinGroup, newDataSource(ctx, "a", 100))
joinGroup = append(joinGroup, newDataSource(ctx, "b", 100))
joinGroup = append(joinGroup, newDataSource(ctx, "c", 100))
joinGroup = append(joinGroup, newDataSource(ctx, "d", 100))
solver := &joinReorderDPSolver{
baseSingleGroupJoinOrderSolver: &baseSingleGroupJoinOrderSolver{
ctx: s.ctx,
ctx: ctx,
},
newJoin: s.newMockJoin,
newJoin: newMockJoin(ctx, statsMap),
}
result, err := solver.solve(joinGroup, nil, nil)
c.Assert(err, IsNil)
c.Assert(s.planToString(result), Equals, "MockJoin{MockJoin{a, b}, MockJoin{c, d}}")
require.NoError(t, err)

expected := "MockJoin{MockJoin{a, b}, MockJoin{c, d}}"
require.Equal(t, expected, planToString(result))
}
Loading

0 comments on commit 297455d

Please sign in to comment.