From 98e8cec7edb0a92991e06fe003f72a6390b1ef61 Mon Sep 17 00:00:00 2001
From: tangenta <tangenta@126.com>
Date: Fri, 12 Apr 2024 15:42:23 +0800
Subject: [PATCH] executor: make tablesample work under different partition
 prune modes (#52405)

close pingcap/tidb#52282
---
 pkg/executor/builder.go                       |  2 +-
 pkg/executor/sample.go                        | 80 ++++++++++++-------
 pkg/planner/core/find_best_task.go            |  1 +
 pkg/planner/core/logical_plan_builder.go      |  2 +-
 pkg/planner/core/physical_plans.go            |  3 +-
 .../integrationtest/r/executor/sample.result  | 14 ++++
 tests/integrationtest/t/executor/sample.test  | 10 +++
 7 files changed, 79 insertions(+), 33 deletions(-)

diff --git a/pkg/executor/builder.go b/pkg/executor/builder.go
index 6134ea38ee071..f6584ec80342b 100644
--- a/pkg/executor/builder.go
+++ b/pkg/executor/builder.go
@@ -5161,7 +5161,7 @@ func (b *executorBuilder) buildTableSample(v *plannercore.PhysicalTableSample) *
 		e.sampler = &emptySampler{}
 	} else if v.TableSampleInfo.AstNode.SampleMethod == ast.SampleMethodTypeTiDBRegion {
 		e.sampler = newTableRegionSampler(
-			b.ctx, v.TableInfo, startTS, v.TableSampleInfo.Partitions, v.Schema(),
+			b.ctx, v.TableInfo, startTS, v.PhysicalTableID, v.TableSampleInfo.Partitions, v.Schema(),
 			v.TableSampleInfo.FullSchema, e.RetFieldTypes(), v.Desc)
 	}
 
diff --git a/pkg/executor/sample.go b/pkg/executor/sample.go
index 2b1213129ca54..2c9d07ec0e60b 100644
--- a/pkg/executor/sample.go
+++ b/pkg/executor/sample.go
@@ -75,10 +75,12 @@ type rowSampler interface {
 }
 
 type tableRegionSampler struct {
-	ctx        sessionctx.Context
-	table      table.Table
-	startTS    uint64
-	partTables []table.PartitionedTable
+	ctx             sessionctx.Context
+	table           table.Table
+	startTS         uint64
+	physicalTableID int64
+	partTables      []table.PartitionedTable
+
 	schema     *expression.Schema
 	fullSchema *expression.Schema
 	isDesc     bool
@@ -89,18 +91,28 @@ type tableRegionSampler struct {
 	isFinished   bool
 }
 
-func newTableRegionSampler(ctx sessionctx.Context, t table.Table, startTs uint64, partTables []table.PartitionedTable,
-	schema *expression.Schema, fullSchema *expression.Schema, retTypes []*types.FieldType, desc bool) *tableRegionSampler {
+func newTableRegionSampler(
+	ctx sessionctx.Context,
+	t table.Table,
+	startTs uint64,
+	pyhsicalTableID int64,
+	partTables []table.PartitionedTable,
+	schema *expression.Schema,
+	fullSchema *expression.Schema,
+	retTypes []*types.FieldType,
+	desc bool,
+) *tableRegionSampler {
 	return &tableRegionSampler{
-		ctx:        ctx,
-		table:      t,
-		startTS:    startTs,
-		partTables: partTables,
-		schema:     schema,
-		fullSchema: fullSchema,
-		isDesc:     desc,
-		retTypes:   retTypes,
-		rowMap:     make(map[int64]types.Datum),
+		ctx:             ctx,
+		table:           t,
+		startTS:         startTs,
+		partTables:      partTables,
+		physicalTableID: pyhsicalTableID,
+		schema:          schema,
+		fullSchema:      fullSchema,
+		isDesc:          desc,
+		retTypes:        retTypes,
+		rowMap:          make(map[int64]types.Datum),
 	}
 }
 
@@ -176,23 +188,31 @@ func (s *tableRegionSampler) writeChunkFromRanges(ranges []kv.KeyRange, req *chu
 }
 
 func (s *tableRegionSampler) splitTableRanges() ([]kv.KeyRange, error) {
-	if len(s.partTables) != 0 {
-		var ranges []kv.KeyRange
-		for _, t := range s.partTables {
-			for _, pid := range t.GetAllPartitionIDs() {
-				start := tablecodec.GenTableRecordPrefix(pid)
-				end := start.PrefixNext()
-				rs, err := splitIntoMultiRanges(s.ctx.GetStore(), start, end)
-				if err != nil {
-					return nil, err
-				}
-				ranges = append(ranges, rs...)
-			}
+	partitionTable := s.table.GetPartitionedTable()
+	if partitionTable == nil {
+		startKey, endKey := s.table.RecordPrefix(), s.table.RecordPrefix().PrefixNext()
+		return splitIntoMultiRanges(s.ctx.GetStore(), startKey, endKey)
+	}
+
+	var partIDs []int64
+	if partitionTable.Meta().ID == s.physicalTableID {
+		for _, p := range s.partTables {
+			partIDs = append(partIDs, p.GetAllPartitionIDs()...)
 		}
-		return ranges, nil
+	} else {
+		partIDs = []int64{s.physicalTableID}
 	}
-	startKey, endKey := s.table.RecordPrefix(), s.table.RecordPrefix().PrefixNext()
-	return splitIntoMultiRanges(s.ctx.GetStore(), startKey, endKey)
+	ranges := make([]kv.KeyRange, 0, len(partIDs))
+	for _, pid := range partIDs {
+		start := tablecodec.GenTableRecordPrefix(pid)
+		end := start.PrefixNext()
+		rs, err := splitIntoMultiRanges(s.ctx.GetStore(), start, end)
+		if err != nil {
+			return nil, err
+		}
+		ranges = append(ranges, rs...)
+	}
+	return ranges, nil
 }
 
 func splitIntoMultiRanges(store kv.Storage, startKey, endKey kv.Key) ([]kv.KeyRange, error) {
diff --git a/pkg/planner/core/find_best_task.go b/pkg/planner/core/find_best_task.go
index ed50d9cc1b8b3..dccfae4f463db 100644
--- a/pkg/planner/core/find_best_task.go
+++ b/pkg/planner/core/find_best_task.go
@@ -2535,6 +2535,7 @@ func (ds *DataSource) convertToSampleTable(prop *property.PhysicalProperty,
 	p := PhysicalTableSample{
 		TableSampleInfo: ds.SampleInfo,
 		TableInfo:       ds.table,
+		PhysicalTableID: ds.physicalTableID,
 		Desc:            candidate.isMatchProp && prop.SortItems[0].Desc,
 	}.Init(ds.SCtx(), ds.QueryBlockOffset())
 	p.schema = ds.schema
diff --git a/pkg/planner/core/logical_plan_builder.go b/pkg/planner/core/logical_plan_builder.go
index 6290239053b69..7c3eefc2692c2 100644
--- a/pkg/planner/core/logical_plan_builder.go
+++ b/pkg/planner/core/logical_plan_builder.go
@@ -5017,7 +5017,7 @@ func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName, as
 	ds.SetSchema(schema)
 	ds.names = names
 	ds.setPreferredStoreType(b.TableHints())
-	ds.SampleInfo = NewTableSampleInfo(tn.TableSample, schema.Clone(), b.partitionedTable)
+	ds.SampleInfo = NewTableSampleInfo(tn.TableSample, schema, b.partitionedTable)
 	b.isSampling = ds.SampleInfo != nil
 
 	for i, colExpr := range ds.Schema().Columns {
diff --git a/pkg/planner/core/physical_plans.go b/pkg/planner/core/physical_plans.go
index 7aa9b43a6d92d..1be00cb580356 100644
--- a/pkg/planner/core/physical_plans.go
+++ b/pkg/planner/core/physical_plans.go
@@ -2525,6 +2525,7 @@ type PhysicalTableSample struct {
 	physicalSchemaProducer
 	TableSampleInfo *TableSampleInfo
 	TableInfo       table.Table
+	PhysicalTableID int64
 	Desc            bool
 }
 
@@ -2558,7 +2559,7 @@ func NewTableSampleInfo(node *ast.TableSample, fullSchema *expression.Schema, pt
 	}
 	return &TableSampleInfo{
 		AstNode:    node,
-		FullSchema: fullSchema,
+		FullSchema: fullSchema.Clone(),
 		Partitions: pt,
 	}
 }
diff --git a/tests/integrationtest/r/executor/sample.result b/tests/integrationtest/r/executor/sample.result
index b5879afbe145c..0707a4ac14d0b 100644
--- a/tests/integrationtest/r/executor/sample.result
+++ b/tests/integrationtest/r/executor/sample.result
@@ -207,3 +207,17 @@ pk	v
 500	a
 9223372036854775809	b
 set @@global.tidb_scatter_region=default;
+drop table if exists t;
+create table t (a int, b varchar(255), primary key (a)) partition by hash(a) partitions 2;
+insert into t values (1, '1'), (2, '2'), (3, '3');
+set @@tidb_partition_prune_mode='static';
+select * from t tablesample regions() order by a;
+a	b
+1	1
+2	2
+set @@tidb_partition_prune_mode='dynamic';
+select * from t tablesample regions() order by a;
+a	b
+1	1
+2	2
+set @@tidb_partition_prune_mode=default;
diff --git a/tests/integrationtest/t/executor/sample.test b/tests/integrationtest/t/executor/sample.test
index ffaad378bd920..8a237c8539f17 100644
--- a/tests/integrationtest/t/executor/sample.test
+++ b/tests/integrationtest/t/executor/sample.test
@@ -131,3 +131,13 @@ SPLIT TABLE a BY (500);
 SELECT * FROM a TABLESAMPLE REGIONS() ORDER BY pk;
 
 set @@global.tidb_scatter_region=default;
+
+# TestTableSamplePartitionPruneMode
+drop table if exists t;
+create table t (a int, b varchar(255), primary key (a)) partition by hash(a) partitions 2;
+insert into t values (1, '1'), (2, '2'), (3, '3');
+set @@tidb_partition_prune_mode='static';
+select * from t tablesample regions() order by a;
+set @@tidb_partition_prune_mode='dynamic';
+select * from t tablesample regions() order by a;
+set @@tidb_partition_prune_mode=default;