From b2e18d391f6b441a00dd735add4ce3a45f5534db Mon Sep 17 00:00:00 2001
From: Yiding Cui <winoros@gmail.com>
Date: Tue, 14 May 2024 20:08:24 +0800
Subject: [PATCH 1/4] planner: UPDATE's select plan's output should be stable

---
 pkg/expression/util.go                        | 10 ++++----
 pkg/planner/core/logical_plan_builder.go      |  4 ++--
 pkg/planner/core/rule_eliminate_projection.go |  8 +++----
 pkg/sessionctx/stmtctx/stmtctx.go             |  3 ++-
 .../planner/core/issuetest/planner_issue.test | 24 +++++++++++++++++++
 5 files changed, 36 insertions(+), 13 deletions(-)

diff --git a/pkg/expression/util.go b/pkg/expression/util.go
index db4b4c1ee6d75..129ef2418f74a 100644
--- a/pkg/expression/util.go
+++ b/pkg/expression/util.go
@@ -36,9 +36,9 @@ import (
 	driver "github.com/pingcap/tidb/pkg/types/parser_driver"
 	"github.com/pingcap/tidb/pkg/util/chunk"
 	"github.com/pingcap/tidb/pkg/util/collate"
+	"github.com/pingcap/tidb/pkg/util/intset"
 	"github.com/pingcap/tidb/pkg/util/logutil"
 	"go.uber.org/zap"
-	"golang.org/x/tools/container/intsets"
 )
 
 // cowExprRef is a copy-on-write slice ref util using in `ColumnSubstitute`
@@ -372,15 +372,15 @@ func ExtractColumnsAndCorColumnsFromExpressions(result []*Column, list []Express
 }
 
 // ExtractColumnSet extracts the different values of `UniqueId` for columns in expressions.
-func ExtractColumnSet(exprs ...Expression) *intsets.Sparse {
-	set := &intsets.Sparse{}
+func ExtractColumnSet(exprs ...Expression) intset.FastIntSet {
+	set := intset.NewFastIntSet()
 	for _, expr := range exprs {
-		extractColumnSet(expr, set)
+		extractColumnSet(expr, &set)
 	}
 	return set
 }
 
-func extractColumnSet(expr Expression, set *intsets.Sparse) {
+func extractColumnSet(expr Expression, set *intset.FastIntSet) {
 	switch v := expr.(type) {
 	case *Column:
 		set.Insert(int(v.UniqueID))
diff --git a/pkg/planner/core/logical_plan_builder.go b/pkg/planner/core/logical_plan_builder.go
index 6f41cfe0ba5bc..cebc2980c0eeb 100644
--- a/pkg/planner/core/logical_plan_builder.go
+++ b/pkg/planner/core/logical_plan_builder.go
@@ -6090,8 +6090,8 @@ func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.Tab
 			allAssignmentsAreConstant = false
 		}
 		p = np
-		if col, ok := newExpr.(*expression.Column); ok {
-			b.ctx.GetSessionVars().StmtCtx.ColRefFromUpdatePlan = append(b.ctx.GetSessionVars().StmtCtx.ColRefFromUpdatePlan, col.UniqueID)
+		if cols := expression.ExtractColumnSet(newExpr); cols.Len() > 0 {
+			b.ctx.GetSessionVars().StmtCtx.ColRefFromUpdatePlan.UnionWith(cols)
 		}
 		newList = append(newList, &expression.Assignment{Col: col, ColName: name.ColName, Expr: newExpr})
 		dbName := name.DBName.L
diff --git a/pkg/planner/core/rule_eliminate_projection.go b/pkg/planner/core/rule_eliminate_projection.go
index 00e65a0446e2b..17d2334aa9d0d 100644
--- a/pkg/planner/core/rule_eliminate_projection.go
+++ b/pkg/planner/core/rule_eliminate_projection.go
@@ -82,11 +82,9 @@ func canProjectionBeEliminatedStrict(p *PhysicalProjection) bool {
 	if p.Schema().Len() != child.Schema().Len() {
 		return false
 	}
-	for _, ref := range p.SCtx().GetSessionVars().StmtCtx.ColRefFromUpdatePlan {
-		for _, one := range p.Schema().Columns {
-			if ref == one.UniqueID {
-				return false
-			}
+	for _, col := range p.Schema().Columns {
+		if p.SCtx().GetSessionVars().StmtCtx.ColRefFromUpdatePlan.Has(int(col.UniqueID)) {
+			return false
 		}
 	}
 	for i, expr := range p.Exprs {
diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go
index 3f7d8e7b559e9..6d9b9a2a50683 100644
--- a/pkg/sessionctx/stmtctx/stmtctx.go
+++ b/pkg/sessionctx/stmtctx/stmtctx.go
@@ -41,6 +41,7 @@ import (
 	"github.com/pingcap/tidb/pkg/util/execdetails"
 	"github.com/pingcap/tidb/pkg/util/hint"
 	"github.com/pingcap/tidb/pkg/util/intest"
+	"github.com/pingcap/tidb/pkg/util/intset"
 	"github.com/pingcap/tidb/pkg/util/linter/constructor"
 	"github.com/pingcap/tidb/pkg/util/memory"
 	"github.com/pingcap/tidb/pkg/util/nocopy"
@@ -362,7 +363,7 @@ type StatementContext struct {
 	// UseDynamicPruneMode indicates whether use UseDynamicPruneMode in query stmt
 	UseDynamicPruneMode bool
 	// ColRefFromPlan mark the column ref used by assignment in update statement.
-	ColRefFromUpdatePlan []int64
+	ColRefFromUpdatePlan intset.FastIntSet
 
 	// IsExplainAnalyzeDML is true if the statement is "explain analyze DML executors", before responding the explain
 	// results to the client, the transaction should be committed first. See issue #37373 for more details.
diff --git a/tests/integrationtest/t/planner/core/issuetest/planner_issue.test b/tests/integrationtest/t/planner/core/issuetest/planner_issue.test
index f849ad75a2bdc..5b38d3fa77a94 100644
--- a/tests/integrationtest/t/planner/core/issuetest/planner_issue.test
+++ b/tests/integrationtest/t/planner/core/issuetest/planner_issue.test
@@ -378,3 +378,27 @@ update t_kg74 set
       where (ref_14.c_z like 'o%fiah')))
 where (t_kg74.c_obnq8s7_s2 = case when (t_kg74.c_a1tv2 is NULL) then t_kg74.c_g else t_kg74.c_obnq8s7_s2 end
       );
+
+# https://github.com/pingcap/tidb/issues/53236
+create table t1(id int primary key, a varchar(128));
+create table t2(id int primary key, b varchar(128), c varchar(128));
+UPDATE
+    t1
+SET
+    t1.a = IFNULL(
+            (
+                SELECT
+                    t2.c
+                FROM
+                    t2
+                WHERE
+                    t2.b = t1.a
+                ORDER BY
+                    t2.b DESC,
+                    t2.c DESC
+                LIMIT
+                    1
+            ), ''
+        )
+WHERE
+    t1.id = 1;

From 32b04f4b03ba7a3c4f321ea530a6097eaaffbce9 Mon Sep 17 00:00:00 2001
From: Yiding Cui <winoros@gmail.com>
Date: Tue, 14 May 2024 20:17:03 +0800
Subject: [PATCH 2/4] build: fix bazel_prepare (sync BUILD.bazel deps after switching intsets to intset)

---
 pkg/expression/BUILD.bazel         | 1 -
 pkg/sessionctx/stmtctx/BUILD.bazel | 1 +
 2 files changed, 1 insertion(+), 1 deletion(-)

diff --git a/pkg/expression/BUILD.bazel b/pkg/expression/BUILD.bazel
index 464415390e620..dc65c2f34d7c1 100644
--- a/pkg/expression/BUILD.bazel
+++ b/pkg/expression/BUILD.bazel
@@ -124,7 +124,6 @@ go_library(
         "@com_github_pingcap_failpoint//:failpoint",
         "@com_github_pingcap_tipb//go-tipb",
         "@com_github_tikv_client_go_v2//oracle",
-        "@org_golang_x_tools//container/intsets",
         "@org_uber_go_atomic//:atomic",
         "@org_uber_go_zap//:zap",
     ],
diff --git a/pkg/sessionctx/stmtctx/BUILD.bazel b/pkg/sessionctx/stmtctx/BUILD.bazel
index d766981321934..af71d44944f01 100644
--- a/pkg/sessionctx/stmtctx/BUILD.bazel
+++ b/pkg/sessionctx/stmtctx/BUILD.bazel
@@ -21,6 +21,7 @@ go_library(
         "//pkg/util/execdetails",
         "//pkg/util/hint",
         "//pkg/util/intest",
+        "//pkg/util/intset",
         "//pkg/util/linter/constructor",
         "//pkg/util/memory",
         "//pkg/util/nocopy",

From 1b6ca01ee21d3e7c48f6beff298c42ba9a732a80 Mon Sep 17 00:00:00 2001
From: Yiding Cui <winoros@gmail.com>
Date: Tue, 14 May 2024 21:16:51 +0800
Subject: [PATCH 3/4] planner: make the ColRefFromUpdatePlan check in projection elimination stricter

---
 pkg/planner/core/rule_eliminate_projection.go | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/pkg/planner/core/rule_eliminate_projection.go b/pkg/planner/core/rule_eliminate_projection.go
index 17d2334aa9d0d..68cb0b1decba4 100644
--- a/pkg/planner/core/rule_eliminate_projection.go
+++ b/pkg/planner/core/rule_eliminate_projection.go
@@ -82,11 +82,6 @@ func canProjectionBeEliminatedStrict(p *PhysicalProjection) bool {
 	if p.Schema().Len() != child.Schema().Len() {
 		return false
 	}
-	for _, col := range p.Schema().Columns {
-		if p.SCtx().GetSessionVars().StmtCtx.ColRefFromUpdatePlan.Has(int(col.UniqueID)) {
-			return false
-		}
-	}
 	for i, expr := range p.Exprs {
 		col, ok := expr.(*expression.Column)
 		if !ok || !col.EqualColumn(child.Schema().Columns[i]) {
@@ -144,6 +139,11 @@ func doPhysicalProjectionElimination(p base.PhysicalPlan) base.PhysicalPlan {
 			childProj.SetSchema(p.Schema())
 		}
 	}
+	for i, col := range p.Schema().Columns {
+		if p.SCtx().GetSessionVars().StmtCtx.ColRefFromUpdatePlan.Has(int(col.UniqueID)) && !child.Schema().Columns[i].Equal(nil, col) {
+			return p
+		}
+	}
 	return child
 }
 

From 13db18e922eefa894ec36e20cbb8c16b77a89f12 Mon Sep 17 00:00:00 2001
From: Yiding Cui <winoros@gmail.com>
Date: Tue, 14 May 2024 21:36:47 +0800
Subject: [PATCH 4/4] fix test outputs

---
 .../planner/core/casetest/integration.result  | 26 ++++++++-----------
 .../integrationtest/r/planner/core/cbo.result | 15 +++++------
 .../core/rule_constant_propagation.result     | 16 +++++-------
 3 files changed, 25 insertions(+), 32 deletions(-)

diff --git a/tests/integrationtest/r/planner/core/casetest/integration.result b/tests/integrationtest/r/planner/core/casetest/integration.result
index 44c39fa562434..e0191ebbf2e12 100644
--- a/tests/integrationtest/r/planner/core/casetest/integration.result
+++ b/tests/integrationtest/r/planner/core/casetest/integration.result
@@ -1803,18 +1803,14 @@ explain format = brief update tt, (select 1 as c1 ,2 as c2 ,3 as c3, 4 as c4 uni
 id	estRows	task	access object	operator info
 Update	N/A	root		N/A
 └─Projection	0.00	root		test.tt.a, test.tt.b, test.tt.c, test.tt.d, test.tt.e, Column#18, Column#19, Column#20, Column#21
-  └─Projection	0.00	root		test.tt.a, test.tt.b, test.tt.c, test.tt.d, test.tt.e, Column#18, Column#19, Column#20, Column#21
-    └─IndexJoin	0.00	root		inner join, inner:TableReader, outer key:Column#20, Column#21, inner key:test.tt.c, test.tt.d, equal cond:eq(Column#20, test.tt.c), eq(Column#21, test.tt.d), other cond:or(or(and(eq(Column#20, 11), eq(test.tt.d, 111)), and(eq(Column#20, 22), eq(test.tt.d, 222))), or(and(eq(Column#20, 33), eq(test.tt.d, 333)), and(eq(Column#20, 44), eq(test.tt.d, 444)))), or(or(and(eq(test.tt.c, 11), eq(Column#21, 111)), and(eq(test.tt.c, 22), eq(Column#21, 222))), or(and(eq(test.tt.c, 33), eq(Column#21, 333)), and(eq(test.tt.c, 44), eq(Column#21, 444))))
-      ├─Union(Build)	0.00	root		
-      │ ├─Projection	0.00	root		Column#6->Column#18, Column#7->Column#19, Column#8->Column#20, Column#9->Column#21
-      │ │ └─Projection	0.00	root		1->Column#6, 2->Column#7, 3->Column#8, 4->Column#9
-      │ │   └─TableDual	0.00	root		rows:0
-      │ ├─Projection	0.00	root		Column#10->Column#18, Column#11->Column#19, Column#12->Column#20, Column#13->Column#21
-      │ │ └─Projection	0.00	root		2->Column#10, 3->Column#11, 4->Column#12, 5->Column#13
-      │ │   └─TableDual	0.00	root		rows:0
-      │ └─Projection	0.00	root		Column#14->Column#18, Column#15->Column#19, Column#16->Column#20, Column#17->Column#21
-      │   └─Projection	0.00	root		3->Column#14, 4->Column#15, 5->Column#16, 6->Column#17
-      │     └─TableDual	0.00	root		rows:0
-      └─TableReader(Probe)	0.00	root		data:Selection
-        └─Selection	0.00	cop[tikv]		or(or(and(eq(test.tt.c, 11), eq(test.tt.d, 111)), and(eq(test.tt.c, 22), eq(test.tt.d, 222))), or(and(eq(test.tt.c, 33), eq(test.tt.d, 333)), and(eq(test.tt.c, 44), eq(test.tt.d, 444)))), or(or(eq(test.tt.c, 11), eq(test.tt.c, 22)), or(eq(test.tt.c, 33), eq(test.tt.c, 44))), or(or(eq(test.tt.d, 111), eq(test.tt.d, 222)), or(eq(test.tt.d, 333), eq(test.tt.d, 444)))
-          └─TableRangeScan	0.00	cop[tikv]	table:tt	range: decided by [eq(test.tt.c, Column#20) eq(test.tt.d, Column#21)], keep order:false, stats:pseudo
+  └─IndexJoin	0.00	root		inner join, inner:TableReader, outer key:Column#20, Column#21, inner key:test.tt.c, test.tt.d, equal cond:eq(Column#20, test.tt.c), eq(Column#21, test.tt.d), other cond:or(or(and(eq(Column#20, 11), eq(test.tt.d, 111)), and(eq(Column#20, 22), eq(test.tt.d, 222))), or(and(eq(Column#20, 33), eq(test.tt.d, 333)), and(eq(Column#20, 44), eq(test.tt.d, 444)))), or(or(and(eq(test.tt.c, 11), eq(Column#21, 111)), and(eq(test.tt.c, 22), eq(Column#21, 222))), or(and(eq(test.tt.c, 33), eq(Column#21, 333)), and(eq(test.tt.c, 44), eq(Column#21, 444))))
+    ├─Union(Build)	0.00	root		
+    │ ├─Projection	0.00	root		1->Column#18, 2->Column#19, 3->Column#20, 4->Column#21
+    │ │ └─TableDual	0.00	root		rows:0
+    │ ├─Projection	0.00	root		2->Column#18, 3->Column#19, 4->Column#20, 5->Column#21
+    │ │ └─TableDual	0.00	root		rows:0
+    │ └─Projection	0.00	root		3->Column#18, 4->Column#19, 5->Column#20, 6->Column#21
+    │   └─TableDual	0.00	root		rows:0
+    └─TableReader(Probe)	0.00	root		data:Selection
+      └─Selection	0.00	cop[tikv]		or(or(and(eq(test.tt.c, 11), eq(test.tt.d, 111)), and(eq(test.tt.c, 22), eq(test.tt.d, 222))), or(and(eq(test.tt.c, 33), eq(test.tt.d, 333)), and(eq(test.tt.c, 44), eq(test.tt.d, 444)))), or(or(eq(test.tt.c, 11), eq(test.tt.c, 22)), or(eq(test.tt.c, 33), eq(test.tt.c, 44))), or(or(eq(test.tt.d, 111), eq(test.tt.d, 222)), or(eq(test.tt.d, 333), eq(test.tt.d, 444)))
+        └─TableRangeScan	0.00	cop[tikv]	table:tt	range: decided by [eq(test.tt.c, Column#20) eq(test.tt.d, Column#21)], keep order:false, stats:pseudo
diff --git a/tests/integrationtest/r/planner/core/cbo.result b/tests/integrationtest/r/planner/core/cbo.result
index 5a6420b0e6dc6..a14fd3c08c88e 100644
--- a/tests/integrationtest/r/planner/core/cbo.result
+++ b/tests/integrationtest/r/planner/core/cbo.result
@@ -3,14 +3,13 @@ create table t(a int, b int);
 explain update t t1, (select distinct b from t) t2 set t1.b = t2.b;
 id	estRows	task	access object	operator info
 Update_7	N/A	root		N/A
-└─Projection_9	80000000.00	root		planner__core__cbo.t.a, planner__core__cbo.t.b, planner__core__cbo.t._tidb_rowid, planner__core__cbo.t.b
-  └─HashJoin_10	80000000.00	root		CARTESIAN inner join
-    ├─HashAgg_18(Build)	8000.00	root		group by:planner__core__cbo.t.b, funcs:firstrow(planner__core__cbo.t.b)->planner__core__cbo.t.b
-    │ └─TableReader_19	8000.00	root		data:HashAgg_14
-    │   └─HashAgg_14	8000.00	cop[tikv]		group by:planner__core__cbo.t.b, 
-    │     └─TableFullScan_17	10000.00	cop[tikv]	table:t	keep order:false, stats:pseudo
-    └─TableReader_13(Probe)	10000.00	root		data:TableFullScan_12
-      └─TableFullScan_12	10000.00	cop[tikv]	table:t1	keep order:false, stats:pseudo
+└─HashJoin_10	80000000.00	root		CARTESIAN inner join
+  ├─HashAgg_18(Build)	8000.00	root		group by:planner__core__cbo.t.b, funcs:firstrow(planner__core__cbo.t.b)->planner__core__cbo.t.b
+  │ └─TableReader_19	8000.00	root		data:HashAgg_14
+  │   └─HashAgg_14	8000.00	cop[tikv]		group by:planner__core__cbo.t.b, 
+  │     └─TableFullScan_17	10000.00	cop[tikv]	table:t	keep order:false, stats:pseudo
+  └─TableReader_13(Probe)	10000.00	root		data:TableFullScan_12
+    └─TableFullScan_12	10000.00	cop[tikv]	table:t1	keep order:false, stats:pseudo
 drop table if exists tb1, tb2;
 create table tb1(a int, b int, primary key(a));
 create table tb2 (a int, b int, c int, d datetime, primary key(c),key idx_u(a));
diff --git a/tests/integrationtest/r/planner/core/rule_constant_propagation.result b/tests/integrationtest/r/planner/core/rule_constant_propagation.result
index d87e4458c8f4e..898fe2fb0660e 100644
--- a/tests/integrationtest/r/planner/core/rule_constant_propagation.result
+++ b/tests/integrationtest/r/planner/core/rule_constant_propagation.result
@@ -149,15 +149,13 @@ create table s (id int, name varchar(10));
 explain Update t, (select * from s where s.id>1) tmp set t.name=tmp.name where t.id=tmp.id;
 id	estRows	task	access object	operator info
 Update_8	N/A	root		N/A
-└─Projection_11	4166.67	root		planner__core__rule_constant_propagation.t.id, planner__core__rule_constant_propagation.t.name, planner__core__rule_constant_propagation.t._tidb_rowid, planner__core__rule_constant_propagation.s.id, planner__core__rule_constant_propagation.s.name
-  └─HashJoin_12	4166.67	root		inner join, equal:[eq(planner__core__rule_constant_propagation.t.id, planner__core__rule_constant_propagation.s.id)]
-    ├─Projection_17(Build)	3333.33	root		planner__core__rule_constant_propagation.s.id, planner__core__rule_constant_propagation.s.name
-    │ └─TableReader_20	3333.33	root		data:Selection_19
-    │   └─Selection_19	3333.33	cop[tikv]		gt(planner__core__rule_constant_propagation.s.id, 1), not(isnull(planner__core__rule_constant_propagation.s.id))
-    │     └─TableFullScan_18	10000.00	cop[tikv]	table:s	keep order:false, stats:pseudo
-    └─TableReader_16(Probe)	3333.33	root		data:Selection_15
-      └─Selection_15	3333.33	cop[tikv]		gt(planner__core__rule_constant_propagation.t.id, 1), not(isnull(planner__core__rule_constant_propagation.t.id))
-        └─TableFullScan_14	10000.00	cop[tikv]	table:t	keep order:false, stats:pseudo
+└─HashJoin_12	4166.67	root		inner join, equal:[eq(planner__core__rule_constant_propagation.t.id, planner__core__rule_constant_propagation.s.id)]
+  ├─TableReader_20(Build)	3333.33	root		data:Selection_19
+  │ └─Selection_19	3333.33	cop[tikv]		gt(planner__core__rule_constant_propagation.s.id, 1), not(isnull(planner__core__rule_constant_propagation.s.id))
+  │   └─TableFullScan_18	10000.00	cop[tikv]	table:s	keep order:false, stats:pseudo
+  └─TableReader_16(Probe)	3333.33	root		data:Selection_15
+    └─Selection_15	3333.33	cop[tikv]		gt(planner__core__rule_constant_propagation.t.id, 1), not(isnull(planner__core__rule_constant_propagation.t.id))
+      └─TableFullScan_14	10000.00	cop[tikv]	table:t	keep order:false, stats:pseudo
 drop table if exists t, s;
 create table t (id int, name varchar(10));
 create table s (id int, name varchar(10));