From af934cf36d2c87650f69e5322682d7424b5ed1a4 Mon Sep 17 00:00:00 2001 From: Ti Chi Robot Date: Fri, 15 Nov 2024 16:54:53 +0800 Subject: [PATCH] planner: fix column evaluator can not detect input's column-ref and thus swapping and destroying later column ref projection logic (#53794) (#57380) close pingcap/tidb#53713 --- pkg/expression/evaluator.go | 99 +++++++++++++++++++++++++++++++- pkg/expression/evaluator_test.go | 40 +++++++++++++ pkg/util/disjointset/BUILD.bazel | 5 +- pkg/util/disjointset/int_set.go | 1 + pkg/util/disjointset/set.go | 85 +++++++++++++++++++++++++++ 5 files changed, 228 insertions(+), 2 deletions(-) create mode 100644 pkg/util/disjointset/set.go diff --git a/pkg/expression/evaluator.go b/pkg/expression/evaluator.go index 904db0ab4bfdc..9bef247ec9b32 100644 --- a/pkg/expression/evaluator.go +++ b/pkg/expression/evaluator.go @@ -15,11 +15,17 @@ package expression import ( + "sync/atomic" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/disjointset" + "github.com/pingcap/tidb/pkg/util/intest" ) type columnEvaluator struct { inputIdxToOutputIdxes map[int][]int + // mergedInputIdxToOutputIdxes is only determined at runtime, once an input chunk has been seen. + mergedInputIdxToOutputIdxes atomic.Pointer[map[int][]int] } // run evaluates "Column" expressions. @@ -27,7 +33,11 @@ type columnEvaluator struct { // // since it will change the content of the input Chunk. func (e *columnEvaluator) run(ctx EvalContext, input, output *chunk.Chunk) error { - for inputIdx, outputIdxes := range e.inputIdxToOutputIdxes { + // mergedInputIdxToOutputIdxes can only be determined at runtime, once we have seen the input chunk structure. 
+ if e.mergedInputIdxToOutputIdxes.Load() == nil { + e.mergeInputIdxToOutputIdxes(input, e.inputIdxToOutputIdxes) + } + for inputIdx, outputIdxes := range *e.mergedInputIdxToOutputIdxes.Load() { if err := output.SwapColumn(outputIdxes[0], input, inputIdx); err != nil { return err } @@ -38,6 +48,93 @@ func (e *columnEvaluator) run(ctx EvalContext, input, output *chunk.Chunk) error return nil } +// mergeInputIdxToOutputIdxes merges separate inputIdxToOutputIdxes entries when column references +// are detected within the input chunk. This process ensures consistent handling of columns derived +// from the same original source. +// +// Consider the following scenario: +// +// The initial scan operation produces a column 'a': +// +// scan: a (addr: ???) +// +// This column 'a' is used in the first projection (proj1) to create two columns a1 and a2, both referencing 'a': +// +// proj1 +// / \ +// / \ +// / \ +// a1 (addr: 0xe) a2 (addr: 0xe) +// / \ +// / \ +// / \ +// proj2 proj2 +// / \ / \ +// / \ / \ +// a3 a4 a5 a6 +// +// (addr: 0xe) (addr: 0xe) (addr: 0xe) (addr: 0xe) +// +// Here, a1 and a2 share the same address (0xe), indicating they reference the same data from the original 'a'. +// +// When moving to the second projection (proj2), the system tries to project these columns further: +// - The first set (left side) consists of a3 and a4, derived from a1, both retaining the address (0xe). +// - The second set (right side) consists of a5 and a6, derived from a2, also starting with address (0xe). +// +// When proj1 is complete, the output chunk contains two columns [a1, a2], both derived from the single column 'a' from the scan. +// Since both a1 and a2 are column references with the same address (0xe), they are treated as referencing the same data. +// +// In proj2, two separate items are created: +// - <0, [0,1]>: This means the 0th input column (a1) is projected twice, into the 0th and 1st columns of the output chunk. 
+// - <1, [2,3]>: This means the 1st input column (a2) is projected twice, into the 2nd and 3rd columns of the output chunk. +// +// Due to the column swapping logic in each projection, after applying the <0, [0,1]> projection, +// the addresses for a1 and a2 may become swapped or invalid: +// +// proj1: a1 (addr: invalid) a2 (addr: invalid) +// +// This can lead to issues in proj2, where further operations on these columns may be unsafe: +// +// proj2: a3 (addr: 0xe) a4 (addr: 0xe) a5 (addr: ???) a6 (addr: ???) +// +// Therefore, it's crucial to identify and merge the original column references early, ensuring +// the final inputIdxToOutputIdxes mapping accurately reflects the shared origins of the data. +// For instance, <0, [0,1,2,3]> indicates that the 0th input column (original 'a') is referenced +// by all four output columns in the final output. +// +// mergeInputIdxToOutputIdxes merges inputIdxToOutputIdxes based on detected column references. +// This ensures that columns with the same reference are correctly handled in the output chunk. +func (e *columnEvaluator) mergeInputIdxToOutputIdxes(input *chunk.Chunk, inputIdxToOutputIdxes map[int][]int) { + originalDJSet := disjointset.NewSet[int](4) + flag := make([]bool, input.NumCols()) + // Detect self column-references inside the input chunk by comparing column pointers; flag[j] marks a column already unioned into an earlier group so it is not scanned again. + for i := 0; i < input.NumCols(); i++ { + if flag[i] { + continue + } + for j := i + 1; j < input.NumCols(); j++ { + if input.Column(i) == input.Column(j) { + flag[j] = true + originalDJSet.Union(i, j) + } + } + } + // Merge inputIdxToOutputIdxes based on the detected column references. + newInputIdxToOutputIdxes := make(map[int][]int, len(inputIdxToOutputIdxes)) + for inputIdx := range inputIdxToOutputIdxes { + // FindRoot returns the set's internal offset, not the column index, so it is mapped back to a column index with FindVal. 
+ originalRootIdx := originalDJSet.FindRoot(inputIdx) + originalVal, ok := originalDJSet.FindVal(originalRootIdx) + intest.Assert(ok) + mergedOutputIdxes := newInputIdxToOutputIdxes[originalVal] + mergedOutputIdxes = append(mergedOutputIdxes, inputIdxToOutputIdxes[inputIdx]...) + newInputIdxToOutputIdxes[originalVal] = mergedOutputIdxes + } + // Publish the merged inputIdxToOutputIdxes atomically. + // If the compare-and-swap fails, another worker has already published its result in the meantime. + e.mergedInputIdxToOutputIdxes.CompareAndSwap(nil, &newInputIdxToOutputIdxes) +} + type defaultEvaluator struct { outputIdxes []int exprs []Expression diff --git a/pkg/expression/evaluator_test.go b/pkg/expression/evaluator_test.go index b9a08f3c264fe..e77279f4e318e 100644 --- a/pkg/expression/evaluator_test.go +++ b/pkg/expression/evaluator_test.go @@ -15,6 +15,7 @@ package expression import ( + "slices" "testing" "time" @@ -606,3 +607,42 @@ func TestMod(t *testing.T) { require.NoError(t, err) require.Equal(t, types.NewDatum(1.5), r) } + +func TestMergeInputIdxToOutputIdxes(t *testing.T) { + ctx := createContext(t) + inputIdxToOutputIdxes := make(map[int][]int) + // The 0th input column should be referenced by the 0th and 1st output columns. + inputIdxToOutputIdxes[0] = []int{0, 1} + // The 1st input column should be referenced by the 2nd and 3rd output columns. + inputIdxToOutputIdxes[1] = []int{2, 3} + columnEval := columnEvaluator{inputIdxToOutputIdxes: inputIdxToOutputIdxes} + + input := chunk.NewEmptyChunk([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong), types.NewFieldType(mysql.TypeLonglong)}) + input.AppendInt64(0, 99) + // The input chunk's 0th and 1st columns refer to the same underlying column. 
+ input.MakeRef(0, 1) + + // chunk: col1 <---(ref) col2 + // ____________/ \___________/ \___ + // proj: col1 col2 col3 col4 + // + // Before the fix, after applying inputIdxToOutputIdxes[0] the original col2 would be a nil pointer, + // which made the subsequent col3/col4 reference projection invalid. 
+ // + // After the fix, the merged mapping should be: inputIdxToOutputIdxes[0]: {0, 1, 2, 3} + + output := chunk.NewEmptyChunk([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong), types.NewFieldType(mysql.TypeLonglong), + types.NewFieldType(mysql.TypeLonglong), types.NewFieldType(mysql.TypeLonglong)}) + + err := columnEval.run(ctx, input, output) + require.NoError(t, err) + // All four output columns are column references pointing to the same column. + require.Equal(t, output.Column(0), output.Column(1)) + require.Equal(t, output.Column(1), output.Column(2)) + require.Equal(t, output.Column(2), output.Column(3)) + require.Equal(t, output.GetRow(0).GetInt64(0), int64(99)) + + require.Equal(t, len(*columnEval.mergedInputIdxToOutputIdxes.Load()), 1) + slices.Sort((*columnEval.mergedInputIdxToOutputIdxes.Load())[0]) + require.Equal(t, (*columnEval.mergedInputIdxToOutputIdxes.Load())[0], []int{0, 1, 2, 3}) +} diff --git a/pkg/util/disjointset/BUILD.bazel b/pkg/util/disjointset/BUILD.bazel index 941410ed9d54b..293f0194a0760 100644 --- a/pkg/util/disjointset/BUILD.bazel +++ b/pkg/util/disjointset/BUILD.bazel @@ -2,7 +2,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "disjointset", - srcs = ["int_set.go"], + srcs = [ + "int_set.go", + "set.go", + ], importpath = "github.com/pingcap/tidb/pkg/util/disjointset", visibility = ["//visibility:public"], ) diff --git a/pkg/util/disjointset/int_set.go b/pkg/util/disjointset/int_set.go index 05846e3840850..1396c634965aa 100644 --- a/pkg/util/disjointset/int_set.go +++ b/pkg/util/disjointset/int_set.go @@ -38,6 +38,7 @@ func (m *IntSet) FindRoot(a int) int { if a == m.parent[a] { return a } + // Path compression: point a directly at the root found by the recursive call (combined with union by rank, this yields the inverse Ackermann bound). 
+ m.parent[a] = m.FindRoot(m.parent[a]) return m.parent[a] } diff --git a/pkg/util/disjointset/set.go b/pkg/util/disjointset/set.go new file mode 100644 index 0000000000000..9e8eee37f7677 --- /dev/null +++ b/pkg/util/disjointset/set.go @@ -0,0 +1,85 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package disjointset + +// Set is the universal implementation of a disjoint set. +// It's designed for sparse cases or non-integer types. +// If you are dealing with contiguous integers, you should use IntSet to avoid the cost of a hash map. +// Each value is mapped to an integer index (val2Idx / idx2Val), and the core disjoint set algorithm runs on those indexes. +// Time complexity: find uses path compression. NOTE(review): Union links roots without union by rank/size, so the inverse Ackermann bound is not guaranteed; path compression alone gives amortized O(log n) per operation. +type Set[T comparable] struct { + parent []int + val2Idx map[T]int + idx2Val map[int]T + tailIdx int +} + +// NewSet creates an empty disjoint set; size is only a capacity hint for the internal slice and maps. +func NewSet[T comparable](size int) *Set[T] { + return &Set[T]{ + parent: make([]int, 0, size), + val2Idx: make(map[T]int, size), + idx2Val: make(map[int]T, size), + tailIdx: 0, + } +} + +func (s *Set[T]) findRootOriginalVal(a T) int { + idx, ok := s.val2Idx[a] + if !ok { + s.parent = append(s.parent, s.tailIdx) + s.val2Idx[a] = s.tailIdx + s.tailIdx++ + s.idx2Val[s.tailIdx-1] = a + return s.tailIdx - 1 + } + return s.findRootInternal(idx) +} + +// findRootInternal is the index-based find with path compression. Call it inside findRootOriginalVal. 
+func (s *Set[T]) findRootInternal(a int) int { + if s.parent[a] != a { + // Path compression: point a directly at its root so that a later lookup of a takes a single hop. + s.parent[a] = s.findRootInternal(s.parent[a]) + } + return s.parent[a] +} + +// InSameGroup checks whether a and b are in the same group. A value not seen before is registered as a new single-element group as a side effect. +func (s *Set[T]) InSameGroup(a, b T) bool { + return s.findRootOriginalVal(a) == s.findRootOriginalVal(b) +} + +// Union merges the sets that contain a and b, registering either value first if it is new. +func (s *Set[T]) Union(a, b T) { + rootA := s.findRootOriginalVal(a) + rootB := s.findRootOriginalVal(b) + // Attach b's root under a's root, so rootA stays the root of the merged set. + if rootA != rootB { + s.parent[rootB] = rootA + } +} + +// FindRoot returns the internal index (not the value) of the root of the set that contains a; use FindVal to map it back to a value. +func (s *Set[T]) FindRoot(a T) int { + // If a is not in the set yet, it is assigned a new index and becomes its own root. + return s.findRootOriginalVal(a) +} + +// FindVal returns the value stored at the root of the set that contains the internal index idx. NOTE(review): idx is used to index parent without a bounds check, so an idx that FindRoot never returned panics instead of yielding ok == false. +func (s *Set[T]) FindVal(idx int) (T, bool) { + v, ok := s.idx2Val[s.findRootInternal(idx)] + return v, ok +}