diff --git a/distsql/distsql_test.go b/distsql/distsql_test.go
index e196015b243e6..a5e5e389acdc3 100644
--- a/distsql/distsql_test.go
+++ b/distsql/distsql_test.go
@@ -19,6 +19,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/cznic/mathutil"
 	. "github.com/pingcap/check"
 	"github.com/pingcap/errors"
 	"github.com/pingcap/parser/charset"
@@ -34,7 +35,7 @@ import (
 	"github.com/pingcap/tipb/go-tipb"
 )
 
-func (s *testSuite) TestSelectNormal(c *C) {
+func (s *testSuite) createSelectNormal(batch, totalRows int, c *C) (*selectResult, []*types.FieldType) {
 	request, err := (&RequestBuilder{}).SetKeyRanges(nil).
 		SetDAGRequest(&tipb.DAGRequest{}).
 		SetDesc(false).
@@ -67,13 +68,23 @@ func (s *testSuite) TestSelectNormal(c *C) {
 	c.Assert(result.sqlType, Equals, "general")
 	c.Assert(result.rowLen, Equals, len(colTypes))
 
+	resp, ok := result.resp.(*mockResponse)
+	c.Assert(ok, IsTrue)
+	resp.total = totalRows
+	resp.batch = batch
+
+	return result, colTypes
+}
+
+func (s *testSuite) TestSelectNormal(c *C) {
+	response, colTypes := s.createSelectNormal(1, 2, c)
 	response.Fetch(context.TODO())
 
 	// Test Next.
 	chk := chunk.New(colTypes, 32, 32)
 	numAllRows := 0
 	for {
-		err = response.Next(context.TODO(), chk)
+		err := response.Next(context.TODO(), chk)
 		c.Assert(err, IsNil)
 		numAllRows += chk.NumRows()
 		if chk.NumRows() == 0 {
@@ -81,11 +92,18 @@ func (s *testSuite) TestSelectNormal(c *C) {
 		}
 	}
 	c.Assert(numAllRows, Equals, 2)
-	err = response.Close()
+	err := response.Close()
 	c.Assert(err, IsNil)
 }
 
-func (s *testSuite) TestSelectStreaming(c *C) {
+func (s *testSuite) TestSelectNormalChunkSize(c *C) {
+	response, colTypes := s.createSelectNormal(100, 1000000, c)
+	response.Fetch(context.TODO())
+	s.testChunkSize(response, colTypes, c)
+	c.Assert(response.Close(), IsNil)
+}
+
+func (s *testSuite) createSelectStreaming(batch, totalRows int, c *C) (*streamResult, []*types.FieldType) {
 	request, err := (&RequestBuilder{}).SetKeyRanges(nil).
 		SetDAGRequest(&tipb.DAGRequest{}).
 		SetDesc(false).
@@ -112,20 +130,29 @@ func (s *testSuite) TestSelectStreaming(c *C) {
 
 	s.sctx.GetSessionVars().EnableStreaming = true
 
-	// Test Next.
 	response, err := Select(context.TODO(), s.sctx, request, colTypes, statistics.NewQueryFeedback(0, nil, 0, false))
 	c.Assert(err, IsNil)
 	result, ok := response.(*streamResult)
 	c.Assert(ok, IsTrue)
 	c.Assert(result.rowLen, Equals, len(colTypes))
 
+	resp, ok := result.resp.(*mockResponse)
+	c.Assert(ok, IsTrue)
+	resp.total = totalRows
+	resp.batch = batch
+
+	return result, colTypes
+}
+
+func (s *testSuite) TestSelectStreaming(c *C) {
+	response, colTypes := s.createSelectStreaming(1, 2, c)
 	response.Fetch(context.TODO())
 
 	// Test Next.
 	chk := chunk.New(colTypes, 32, 32)
 	numAllRows := 0
 	for {
-		err = response.Next(context.TODO(), chk)
+		err := response.Next(context.TODO(), chk)
 		c.Assert(err, IsNil)
 		numAllRows += chk.NumRows()
 		if chk.NumRows() == 0 {
@@ -133,10 +160,64 @@ func (s *testSuite) TestSelectStreaming(c *C) {
 		}
 	}
 	c.Assert(numAllRows, Equals, 2)
-	err = response.Close()
+	err := response.Close()
 	c.Assert(err, IsNil)
 }
 
+func (s *testSuite) TestSelectStreamingChunkSize(c *C) {
+	response, colTypes := s.createSelectStreaming(100, 1000000, c)
+	response.Fetch(context.TODO())
+	s.testChunkSize(response, colTypes, c)
+	c.Assert(response.Close(), IsNil)
+}
+
+func (s *testSuite) testChunkSize(response SelectResult, colTypes []*types.FieldType, c *C) {
+	chk := chunk.New(colTypes, 32, 32)
+
+	err := response.Next(context.TODO(), chk)
+	c.Assert(err, IsNil)
+	c.Assert(chk.NumRows(), Equals, 32)
+
+	err = response.Next(context.TODO(), chk)
+	c.Assert(err, IsNil)
+	c.Assert(chk.NumRows(), Equals, 32)
+
+	chk.SetRequiredRows(1, 32)
+	err = response.Next(context.TODO(), chk)
+	c.Assert(err, IsNil)
+	c.Assert(chk.NumRows(), Equals, 1)
+
+	chk.SetRequiredRows(2, 32)
+	err = response.Next(context.TODO(), chk)
+	c.Assert(err, IsNil)
+	c.Assert(chk.NumRows(), Equals, 2)
+
+	chk.SetRequiredRows(17, 32)
+	err = response.Next(context.TODO(), chk)
+	c.Assert(err, IsNil)
+	c.Assert(chk.NumRows(), Equals, 17)
+
+	chk.SetRequiredRows(170, 32)
+	err = response.Next(context.TODO(), chk)
+	c.Assert(err, IsNil)
+	c.Assert(chk.NumRows(), Equals, 32)
+
+	chk.SetRequiredRows(32, 32)
+	err = response.Next(context.TODO(), chk)
+	c.Assert(err, IsNil)
+	c.Assert(chk.NumRows(), Equals, 32)
+
+	chk.SetRequiredRows(0, 32)
+	err = response.Next(context.TODO(), chk)
+	c.Assert(err, IsNil)
+	c.Assert(chk.NumRows(), Equals, 32)
+
+	chk.SetRequiredRows(-1, 32)
+	err = response.Next(context.TODO(), chk)
+	c.Assert(err, IsNil)
+	c.Assert(chk.NumRows(), Equals, 32)
+}
+
 func (s *testSuite) TestAnalyze(c *C) {
 	request, err := (&RequestBuilder{}).SetKeyRanges(nil).
 		SetAnalyzeRequest(&tipb.AnalyzeReq{}).
@@ -166,6 +247,8 @@ func (s *testSuite) TestAnalyze(c *C) {
 
 // Used only for test.
 type mockResponse struct {
 	count int
+	total int
+	batch int
 	sync.Mutex
 }
@@ -183,17 +266,24 @@ func (resp *mockResponse) Next(ctx context.Context) (kv.ResultSubset, error) {
 	resp.Lock()
 	defer resp.Unlock()
 
-	if resp.count == 2 {
+	if resp.count >= resp.total {
 		return nil, nil
 	}
-	defer func() { resp.count++ }()
+	numRows := mathutil.Min(resp.batch, resp.total-resp.count)
+	resp.count += numRows
 
 	datum := types.NewIntDatum(1)
 	bytes := make([]byte, 0, 100)
 	bytes, _ = codec.EncodeValue(nil, bytes, datum, datum, datum, datum)
+	chunks := make([]tipb.Chunk, numRows)
+	for i := range chunks {
+		chkData := make([]byte, len(bytes))
+		copy(chkData, bytes)
+		chunks[i] = tipb.Chunk{RowsData: chkData}
+	}
 
 	respPB := &tipb.SelectResponse{
-		Chunks:       []tipb.Chunk{{RowsData: bytes}},
+		Chunks:       chunks,
 		OutputCounts: []int64{1},
 	}
 	respBytes, err := respPB.Marshal()
diff --git a/distsql/request_builder_test.go b/distsql/request_builder_test.go
index 10d319b9c3e65..640127f163594 100644
--- a/distsql/request_builder_test.go
+++ b/distsql/request_builder_test.go
@@ -53,7 +53,10 @@ func (s *testSuite) SetUpSuite(c *C) {
 	ctx := mock.NewContext()
 	ctx.Store = &mock.Store{
 		Client: &mock.Client{
-			MockResponse: &mockResponse{},
+			MockResponse: &mockResponse{
+				batch: 1,
+				total: 2,
+			},
 		},
 	}
 	s.sctx = ctx
@@ -67,7 +70,10 @@ func (s *testSuite) SetUpTest(c *C) {
 	ctx := s.sctx.(*mock.Context)
 	store := ctx.Store.(*mock.Store)
 	store.Client = &mock.Client{
-		MockResponse: &mockResponse{},
+		MockResponse: &mockResponse{
+			batch: 1,
+			total: 2,
+		},
 	}
 }
diff --git a/distsql/select_result.go b/distsql/select_result.go
index 5badfc624ec1c..ca5b6908d0ac4 100644
--- a/distsql/select_result.go
+++ b/distsql/select_result.go
@@ -116,7 +116,7 @@ func (r *selectResult) NextRaw(ctx context.Context) ([]byte, error) {
 
 // Next reads data to the chunk.
 func (r *selectResult) Next(ctx context.Context, chk *chunk.Chunk) error {
 	chk.Reset()
-	for chk.NumRows() < r.ctx.GetSessionVars().MaxChunkSize {
+	for !chk.IsFull() {
 		if r.selectResp == nil || r.respChkIdx == len(r.selectResp.Chunks) {
 			err := r.getSelectResp()
 			if err != nil || r.selectResp == nil {
@@ -169,9 +169,8 @@ func (r *selectResult) getSelectResp() error {
 
 func (r *selectResult) readRowsData(chk *chunk.Chunk) (err error) {
 	rowsData := r.selectResp.Chunks[r.respChkIdx].RowsData
-	maxChunkSize := r.ctx.GetSessionVars().MaxChunkSize
 	decoder := codec.NewDecoder(chk, r.ctx.GetSessionVars().Location())
-	for chk.NumRows() < maxChunkSize && len(rowsData) > 0 {
+	for !chk.IsFull() && len(rowsData) > 0 {
 		for i := 0; i < r.rowLen; i++ {
 			rowsData, err = decoder.DecodeOne(rowsData, i, r.fieldTypes[i])
 			if err != nil {
diff --git a/distsql/stream.go b/distsql/stream.go
index dada7053f7a09..73702f129fed3 100644
--- a/distsql/stream.go
+++ b/distsql/stream.go
@@ -45,8 +45,7 @@ func (r *streamResult) Fetch(context.Context) {}
 
 func (r *streamResult) Next(ctx context.Context, chk *chunk.Chunk) error {
 	chk.Reset()
-	maxChunkSize := r.ctx.GetSessionVars().MaxChunkSize
-	for chk.NumRows() < maxChunkSize {
+	for !chk.IsFull() {
 		err := r.readDataIfNecessary(ctx)
 		if err != nil {
 			return errors.Trace(err)
@@ -115,9 +114,8 @@ func (r *streamResult) readDataIfNecessary(ctx context.Context) error {
 
 func (r *streamResult) flushToChunk(chk *chunk.Chunk) (err error) {
 	remainRowsData := r.curr.RowsData
-	maxChunkSize := r.ctx.GetSessionVars().MaxChunkSize
 	decoder := codec.NewDecoder(chk, r.ctx.GetSessionVars().Location())
-	for chk.NumRows() < maxChunkSize && len(remainRowsData) > 0 {
+	for !chk.IsFull() && len(remainRowsData) > 0 {
 		for i := 0; i < r.rowLen; i++ {
 			remainRowsData, err = decoder.DecodeOne(remainRowsData, i, r.fieldTypes[i])
 			if err != nil {
diff --git a/util/chunk/chunk.go b/util/chunk/chunk.go
index c59afc45cebe4..26c4dc77beb37 100644
--- a/util/chunk/chunk.go
+++ b/util/chunk/chunk.go
@@ -33,7 +33,11 @@ type Chunk struct {
 	// It is used only when this Chunk doesn't hold any data, i.e. "len(columns)==0".
 	numVirtualRows int
 	// capacity indicates the max number of rows this chunk can hold.
+	// TODO: replace all usages of capacity with requiredRows and remove this field.
 	capacity int
+
+	// requiredRows indicates how many rows the parent executor wants.
+	requiredRows int
 }
 
 // Capacity constants.
@@ -63,6 +67,13 @@ func New(fields []*types.FieldType, cap, maxChunkSize int) *Chunk {
 		}
 	}
 	chk.numVirtualRows = 0
+
+	// Set requiredRows to maxChunkSize by default, so that chk.IsFull() behaves
+	// exactly like the old fullness check: the condition
+	//   "chk.NumRows() < maxChunkSize"
+	// is then equivalent to
+	//   "!chk.IsFull()".
+	chk.requiredRows = maxChunkSize
 	return chk
 }
 
@@ -80,6 +91,7 @@ func Renew(chk *Chunk, maxChunkSize int) *Chunk {
 	newChk.columns = renewColumns(chk.columns, newCap)
 	newChk.numVirtualRows = 0
 	newChk.capacity = newCap
+	newChk.requiredRows = maxChunkSize
 	return newChk
 }
 
@@ -133,6 +145,25 @@ func newVarLenColumn(cap int, old *column) *column {
 	}
 }
 
+// RequiredRows returns the number of rows at which this chunk is considered full.
+func (c *Chunk) RequiredRows() int {
+	return c.requiredRows
+}
+
+// SetRequiredRows sets the number of required rows; values out of the range (0, maxChunkSize] fall back to maxChunkSize.
+func (c *Chunk) SetRequiredRows(requiredRows, maxChunkSize int) *Chunk {
+	if requiredRows <= 0 || requiredRows > maxChunkSize {
+		requiredRows = maxChunkSize
+	}
+	c.requiredRows = requiredRows
+	return c
+}
+
+// IsFull returns whether this chunk is considered full.
+func (c *Chunk) IsFull() bool {
+	return c.NumRows() >= c.requiredRows
+}
+
 // MakeRef makes column in "dstColIdx" reference to column in "srcColIdx".
 func (c *Chunk) MakeRef(srcColIdx, dstColIdx int) {
 	c.columns[dstColIdx] = c.columns[srcColIdx]
@@ -225,6 +256,7 @@ func (c *Chunk) GrowAndReset(maxChunkSize int) {
 	c.capacity = newCap
 	c.columns = renewColumns(c.columns, newCap)
 	c.numVirtualRows = 0
+	c.requiredRows = maxChunkSize
 }
 
 // reCalcCapacity calculates the capacity for another Chunk based on the current
diff --git a/util/chunk/chunk_test.go b/util/chunk/chunk_test.go
index c19059bcf0143..c6d478f4cb5a9 100644
--- a/util/chunk/chunk_test.go
+++ b/util/chunk/chunk_test.go
@@ -24,6 +24,7 @@ import (
 	"time"
 	"unsafe"
 
+	"github.com/cznic/mathutil"
 	"github.com/pingcap/check"
 	"github.com/pingcap/parser/mysql"
 	"github.com/pingcap/tidb/sessionctx/stmtctx"
@@ -245,6 +246,59 @@ func (s *testChunkSuite) TestTruncateTo(c *check.C) {
 	c.Assert(chk.GetRow(1).IsNull(0), check.IsTrue)
 }
 
+func (s *testChunkSuite) TestChunkSizeControl(c *check.C) {
+	maxChunkSize := 10
+	chk := New([]*types.FieldType{types.NewFieldType(mysql.TypeLong)}, maxChunkSize, maxChunkSize)
+	c.Assert(chk.RequiredRows(), check.Equals, maxChunkSize)
+
+	for i := 0; i < maxChunkSize; i++ {
+		chk.AppendInt64(0, 1)
+	}
+	maxChunkSize += maxChunkSize / 3
+	chk.GrowAndReset(maxChunkSize)
+	c.Assert(chk.RequiredRows(), check.Equals, maxChunkSize)
+
+	maxChunkSize2 := maxChunkSize + maxChunkSize/3
+	chk2 := Renew(chk, maxChunkSize2)
+	c.Assert(chk2.RequiredRows(), check.Equals, maxChunkSize2)
+
+	chk.Reset()
+	for i := 1; i < maxChunkSize*2; i++ {
+		chk.SetRequiredRows(i, maxChunkSize)
+		c.Assert(chk.RequiredRows(), check.Equals, mathutil.Min(maxChunkSize, i))
+	}
+
+	chk.SetRequiredRows(1, maxChunkSize).
+		SetRequiredRows(2, maxChunkSize).
+		SetRequiredRows(3, maxChunkSize)
+	c.Assert(chk.RequiredRows(), check.Equals, 3)
+
+	chk.SetRequiredRows(-1, maxChunkSize)
+	c.Assert(chk.RequiredRows(), check.Equals, maxChunkSize)
+
+	chk.SetRequiredRows(5, maxChunkSize)
+	chk.AppendInt64(0, 1)
+	chk.AppendInt64(0, 1)
+	chk.AppendInt64(0, 1)
+	chk.AppendInt64(0, 1)
+	c.Assert(chk.NumRows(), check.Equals, 4)
+	c.Assert(chk.IsFull(), check.IsFalse)
+
+	chk.AppendInt64(0, 1)
+	c.Assert(chk.NumRows(), check.Equals, 5)
+	c.Assert(chk.IsFull(), check.IsTrue)
+
+	chk.AppendInt64(0, 1)
+	chk.AppendInt64(0, 1)
+	chk.AppendInt64(0, 1)
+	c.Assert(chk.NumRows(), check.Equals, 8)
+	c.Assert(chk.IsFull(), check.IsTrue)
+
+	chk.SetRequiredRows(maxChunkSize, maxChunkSize)
+	c.Assert(chk.NumRows(), check.Equals, 8)
+	c.Assert(chk.IsFull(), check.IsFalse)
+}
+
 // newChunk creates a new chunk and initialize columns with element length.
 // 0 adds an varlen column, positive len add a fixed length column, negative len adds a interface column.
 func newChunk(elemLen ...int) *Chunk {
diff --git a/util/chunk/recordbatch.go b/util/chunk/recordbatch.go
index 7eb79f54f4333..3adebec898f8c 100644
--- a/util/chunk/recordbatch.go
+++ b/util/chunk/recordbatch.go
@@ -14,6 +14,7 @@ package chunk
 
 // RecordBatch is input parameter of Executor.Next` method.
+// TODO: remove RecordBatch after finishing chunk size control.
 type RecordBatch struct {
 	*Chunk
 }
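For readers skimming the patch, here is a minimal, self-contained sketch of the mechanism it introduces. All names below (`chunk`, `setRequiredRows`, `fill`) are hypothetical stand-ins for `util/chunk.Chunk.SetRequiredRows`/`IsFull` and the `distsql` result readers; only the clamp and early-stop logic mirror the patch:

```go
package main

import "fmt"

const maxChunkSize = 32 // stands in for SessionVars.MaxChunkSize

// chunk keeps only the fields relevant to this patch: a row count and the
// requiredRows threshold set by the parent executor.
type chunk struct {
	numRows      int
	requiredRows int
}

// setRequiredRows mirrors Chunk.SetRequiredRows: out-of-range values
// (<= 0 or > maxChunkSize) fall back to maxChunkSize.
func (c *chunk) setRequiredRows(required int) *chunk {
	if required <= 0 || required > maxChunkSize {
		required = maxChunkSize
	}
	c.requiredRows = required
	return c
}

// isFull mirrors Chunk.IsFull: the chunk is "full" once it holds at least
// requiredRows rows, which may be well below maxChunkSize.
func (c *chunk) isFull() bool { return c.numRows >= c.requiredRows }

// fill plays the role of selectResult.Next/streamResult.Next: keep decoding
// rows until the chunk reports itself full or the data runs out.
func fill(c *chunk, available int) (consumed int) {
	for !c.isFull() && consumed < available {
		c.numRows++ // stands in for decoder.DecodeOne appending one row
		consumed++
	}
	return consumed
}

func main() {
	c := (&chunk{}).setRequiredRows(0)     // 0 is out of range, clamps to 32
	fmt.Println(fill(c, 1000), c.isFull()) // 32 true: old behavior preserved

	c = (&chunk{}).setRequiredRows(5)      // e.g. a LIMIT 5 parent needs only 5 rows
	fmt.Println(fill(c, 1000), c.isFull()) // 5 true: the producer stops early
}
```

Clamping out-of-range values to `maxChunkSize` appears to be what keeps existing call sites unchanged: a chunk that never calls `SetRequiredRows` behaves exactly like the old `chk.NumRows() < maxChunkSize` check, which is also why `New`, `Renew`, and `GrowAndReset` reset `requiredRows` to `maxChunkSize`.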