executor: introduce tagged pointer in hash join v2 (#55470)
ref #53127
windtalker authored Sep 4, 2024
1 parent eb19c10 commit 5ffe4b1
Showing 12 changed files with 280 additions and 66 deletions.
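
This commit replaces the plain `uintptr` row addresses stored in the hash table with tagged pointers: a few bits of each row's hash value are packed into the unused high bits of its 64-bit address, so a probe can reject most non-matching buckets with a single mask-and-compare before touching row memory. As a rough orientation before the diffs, here is a minimal, self-contained Go sketch of the packing scheme; the constant and helper names are illustrative, not the ones defined in the new tagged_ptr.go (which this view does not show):

    package main

    import "fmt"

    // taggedPtr packs tag bits derived from a key's hash value into the
    // unused high bits of a 64-bit row address.
    type taggedPtr uint64

    const tagBits = 16 // illustrative; the commit derives the usable bit count per build

    // tagMask selects the top tagBits bits of a 64-bit value.
    const tagMask = ((uint64(1) << tagBits) - 1) << (64 - tagBits)

    // makeTagged combines a row address (whose high bits are zero) with the
    // tag bits of the row's hash value.
    func makeTagged(addr, hashValue uint64) taggedPtr {
        return taggedPtr(addr | (hashValue & tagMask))
    }

    // mayMatch reports whether a stored entry could hold the probed key: if
    // the probe's tag bits are not all present in the entry, no key in that
    // bucket can match (false positives are possible, false negatives are not).
    func mayMatch(tp taggedPtr, hashValue uint64) bool {
        tag := hashValue & tagMask
        return uint64(tp)&tag == tag
    }

    func main() {
        addr := uint64(0x00007f3a12345678) // heap address, top 16 bits unused
        h := uint64(0xabcd000000000001)
        tp := makeTagged(addr, h)
        fmt.Println(mayMatch(tp, h))                  // true
        fmt.Println(mayMatch(tp, 0x1234000000000001)) // false: tag bits differ
        fmt.Printf("%#x\n", uint64(tp)&^tagMask)      // the address is recovered by masking
    }
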
2 changes: 2 additions & 0 deletions pkg/executor/join/BUILD.bazel
@@ -18,6 +18,7 @@ go_library(
"joiner.go",
"merge_join.go",
"outer_join_probe.go",
"tagged_ptr.go",
],
importpath = "github.com/pingcap/tidb/pkg/executor/join",
visibility = ["//visibility:public"],
@@ -77,6 +78,7 @@ go_test(
"merge_join_test.go",
"right_outer_join_probe_test.go",
"row_table_builder_test.go",
"tagged_ptr_test.go",
],
embed = [":join"],
flaky = True,
Expand Down
19 changes: 14 additions & 5 deletions pkg/executor/join/base_join_probe.go
@@ -106,8 +106,9 @@ type baseJoinProbe struct {
 	selRows  []int
 	usedRows []int
 	// matchedRowsHeaders, serializedKeys is indexed by logical row index
-	matchedRowsHeaders []uintptr // the start address of each matched rows
-	serializedKeys     [][]byte  // used for save serialized keys
+	matchedRowsHeaders   []taggedPtr // the start address of each matched row
+	matchedRowsHashValue []uint64    // the hash value of each matched row
+	serializedKeys       [][]byte    // used to save serialized keys
 	// filterVector and nullKeyVector is indexed by physical row index because the return vector of VectorizedFilter is based on physical row index
 	filterVector  []bool // if there is filter before probe, filterVector saves the filter result
 	nullKeyVector []bool // nullKeyVector[i] = true if any of the key is null
@@ -184,7 +185,12 @@ func (j *baseJoinProbe) SetChunkForProbe(chk *chunk.Chunk) (err error) {
 	if cap(j.matchedRowsHeaders) >= logicalRows {
 		j.matchedRowsHeaders = j.matchedRowsHeaders[:logicalRows]
 	} else {
-		j.matchedRowsHeaders = make([]uintptr, logicalRows)
+		j.matchedRowsHeaders = make([]taggedPtr, logicalRows)
 	}
+	if cap(j.matchedRowsHashValue) >= logicalRows {
+		j.matchedRowsHashValue = j.matchedRowsHashValue[:logicalRows]
+	} else {
+		j.matchedRowsHashValue = make([]uint64, logicalRows)
+	}
 	for i := 0; i < int(j.ctx.partitionNumber); i++ {
 		j.hashValues[i] = j.hashValues[i][:0]
@@ -234,20 +240,22 @@ func (j *baseJoinProbe) SetChunkForProbe(chk *chunk.Chunk) (err error) {
 		if (j.filterVector != nil && !j.filterVector[physicalRowIndex]) || (j.nullKeyVector != nil && j.nullKeyVector[physicalRowIndex]) {
 			// explicit set the matchedRowsHeaders[logicalRowIndex] to nil to indicate there is no matched rows
 			j.matchedRowsHeaders[logicalRowIndex] = 0
+			j.matchedRowsHashValue[logicalRowIndex] = 0
 			continue
 		}
 		hash.Reset()
 		// As the golang doc described, `Hash.Write` never returns an error.
 		// See https://golang.org/pkg/hash/#Hash
 		_, _ = hash.Write(j.serializedKeys[logicalRowIndex])
 		hashValue := hash.Sum64()
+		j.matchedRowsHashValue[logicalRowIndex] = hashValue
 		partIndex := hashValue >> j.ctx.partitionMaskOffset
 		j.hashValues[partIndex] = append(j.hashValues[partIndex], posAndHashValue{hashValue: hashValue, pos: logicalRowIndex})
 	}
 	j.currentProbeRow = 0
 	for i := 0; i < int(j.ctx.partitionNumber); i++ {
 		for index := range j.hashValues[i] {
-			j.matchedRowsHeaders[j.hashValues[i][index].pos] = j.ctx.hashTableContext.hashTable.tables[i].lookup(j.hashValues[i][index].hashValue)
+			j.matchedRowsHeaders[j.hashValues[i][index].pos] = j.ctx.hashTableContext.lookup(i, j.hashValues[i][index].hashValue)
 		}
 	}
 	return
@@ -527,7 +535,8 @@ func NewJoinProbe(ctx *HashJoinCtxV2, workID uint, joinType logicalop.JoinType,
 		}
 	}
 	base.cachedBuildRows = make([]*matchedRowInfo, 0, batchBuildRowSize)
-	base.matchedRowsHeaders = make([]uintptr, 0, chunk.InitialCapacity)
+	base.matchedRowsHeaders = make([]taggedPtr, 0, chunk.InitialCapacity)
+	base.matchedRowsHashValue = make([]uint64, 0, chunk.InitialCapacity)
 	base.selRows = make([]int, 0, chunk.InitialCapacity)
 	for i := 0; i < chunk.InitialCapacity; i++ {
 		base.selRows = append(base.selRows, i)
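
Alongside the tagged chain heads, the probe now caches each row's full 64-bit hash in matchedRowsHashValue. A plausible reason, inferred from the lookup change above: the tag must be re-derived from the hash at every later probe step (the next-pointers stored via setNextRowAddress are tagged too), and the serialized key should not be rehashed each time. A hypothetical sketch of such a chain walk; next and keyEqual stand in for row-table accessors that are not part of the hunks shown here, and the 16-bit tag width is an assumption:

    package join

    type taggedPtr uint64

    const tagMask = uint64(0xffff) << 48 // assume 16 usable tag bits

    // probeChain follows tagged next-pointers for one probe row, using the
    // cached hash value to derive the tag once. Because every stored pointer's
    // tag accumulates the tags of all rows behind it, one failed check rules
    // out the rest of the chain.
    func probeChain(head taggedPtr, cachedHash uint64,
        next func(taggedPtr) taggedPtr, keyEqual func(taggedPtr) bool) bool {
        tag := cachedHash & tagMask
        for ptr := head; ptr != 0; ptr = next(ptr) {
            if uint64(ptr)&tag != tag {
                return false // no remaining row in the chain can hold this key
            }
            if keyEqual(ptr) {
                return true // tags agree and the full key comparison succeeded
            }
        }
        return false
    }
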
12 changes: 8 additions & 4 deletions pkg/executor/join/bench_test.go
@@ -24,22 +24,26 @@ import (

 func BenchmarkHashTableBuild(b *testing.B) {
 	b.StopTimer()
-	rowTable, err := createRowTable(3000000)
+	rowTable, tagBits, err := createRowTable(3000000)
 	if err != nil {
 		b.Fatal(err)
 	}
+	tagHelper := &tagPtrHelper{}
+	tagHelper.init(tagBits)
 	subTable := newSubTable(rowTable)
 	segmentCount := len(rowTable.segments)
 	b.StartTimer()
-	subTable.build(0, segmentCount)
+	subTable.build(0, segmentCount, tagHelper)
 }

 func BenchmarkHashTableConcurrentBuild(b *testing.B) {
 	b.StopTimer()
-	rowTable, err := createRowTable(3000000)
+	rowTable, tagBits, err := createRowTable(3000000)
 	if err != nil {
 		b.Fatal(err)
 	}
+	tagHelper := &tagPtrHelper{}
+	tagHelper.init(tagBits)
 	subTable := newSubTable(rowTable)
 	segmentCount := len(rowTable.segments)
 	buildThreads := 3
@@ -52,7 +56,7 @@ func BenchmarkHashTableConcurrentBuild(b *testing.B) {
 			segmentEnd = segmentCount
 		}
 		wg.Run(func() {
-			subTable.build(segmentStart, segmentEnd)
+			subTable.build(segmentStart, segmentEnd, tagHelper)
 		})
 	}
 	wg.Wait()
20 changes: 18 additions & 2 deletions pkg/executor/join/hash_join_v2.go
@@ -66,15 +66,25 @@ type hashTableContext struct {
 	// its own rowTable
 	rowTables     [][]*rowTable
 	hashTable     *hashTableV2
+	tagHelper     *tagPtrHelper
 	memoryTracker *memory.Tracker
 }

 func (htc *hashTableContext) reset() {
 	htc.rowTables = nil
 	htc.hashTable = nil
+	htc.tagHelper = nil
 	htc.memoryTracker.Detach()
 }

+func (htc *hashTableContext) build(task *buildTask) {
+	htc.hashTable.tables[task.partitionIdx].build(task.segStartIdx, task.segEndIdx, htc.tagHelper)
+}
+
+func (htc *hashTableContext) lookup(partitionIndex int, hashValue uint64) taggedPtr {
+	return htc.hashTable.tables[partitionIndex].lookup(hashValue, htc.tagHelper)
+}
+
 func (htc *hashTableContext) getCurrentRowSegment(workerID, partitionID int, tableMeta *TableMeta, allowCreate bool, firstSegSizeHint uint) *rowTableSegment {
 	if htc.rowTables[workerID][partitionID] == nil {
 		htc.rowTables[workerID][partitionID] = newRowTable(tableMeta)
@@ -100,6 +110,7 @@ func (htc *hashTableContext) finalizeCurrentSeg(workerID, partitionID int, build
 	seg := htc.getCurrentRowSegment(workerID, partitionID, nil, false, 0)
 	builder.rowNumberInCurrentRowTableSeg[partitionID] = 0
 	failpoint.Inject("finalizeCurrentSegPanic", nil)
+	seg.initTaggedBits()
 	seg.finalized = true
 	htc.memoryTracker.Consume(seg.totalUsedBytes())
 }
@@ -119,9 +130,15 @@ func (htc *hashTableContext) mergeRowTablesToHashTable(tableMeta *TableMeta, par
 			totalSegmentCnt += len(rt.segments)
 		}
 	}
+	taggedBits := uint8(maxTaggedBits)
 	for i := 0; i < int(partitionNumber); i++ {
+		for _, seg := range rowTables[i].segments {
+			taggedBits = min(taggedBits, seg.taggedBits)
+		}
 		htc.hashTable.tables[i] = newSubTable(rowTables[i])
 	}
+	htc.tagHelper = &tagPtrHelper{}
+	htc.tagHelper.init(taggedBits)
 	htc.rowTables = nil
 	return totalSegmentCnt
 }
@@ -833,8 +850,7 @@ func (w *BuildWorkerV2) buildHashTable(taskCh chan *buildTask) error {
 	}()
 	for task := range taskCh {
 		start := time.Now()
-		partIdx, segStartIdx, segEndIdx := task.partitionIdx, task.segStartIdx, task.segEndIdx
-		w.HashJoinCtx.hashTableContext.hashTable.tables[partIdx].build(segStartIdx, segEndIdx)
+		w.HashJoinCtx.hashTableContext.build(task)
 		failpoint.Inject("buildHashTablePanic", nil)
 		cost += int64(time.Since(start))
 	}
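
Note how mergeRowTablesToHashTable takes the minimum taggedBits over all finalized segments: the helper only ever uses high bits that are guaranteed unused by every row address, then a single tagPtrHelper is shared by build and lookup. tagged_ptr.go itself is among the files this view does not show, so the following is only a plausible reconstruction of the helper from its call sites (init, getTaggedValue, toTaggedPtr); the field layout, the maxTaggedBits value, and the 64-bit platform assumption are mine:

    package join

    import "unsafe"

    type taggedPtr uintptr

    const maxTaggedBits = 28 // assumed upper bound; the real constant lives in tagged_ptr.go

    type tagPtrHelper struct {
        tagMask uint64 // covers the top taggedBits bits of a 64-bit value
    }

    func (h *tagPtrHelper) init(taggedBits uint8) {
        h.tagMask = ((uint64(1) << taggedBits) - 1) << (64 - uint64(taggedBits))
    }

    // getTaggedValue extracts the tag bits of a hash value; called with
    // hashValue | uint64(prev) during build, it also ORs in the tag already
    // accumulated on the previous head of the chain.
    func (h *tagPtrHelper) getTaggedValue(v uint64) uint64 {
        return v & h.tagMask
    }

    // toTaggedPtr packs a tag into the unused high bits of a row address.
    func (h *tagPtrHelper) toTaggedPtr(tagValue uint64, rowAddress unsafe.Pointer) taggedPtr {
        return taggedPtr(uintptr(tagValue) | uintptr(rowAddress))
    }
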
57 changes: 37 additions & 20 deletions pkg/executor/join/hash_table_v2.go
@@ -20,15 +20,32 @@ import (
 )

 type subTable struct {
-	rowData   *rowTable
-	hashTable []uintptr
+	rowData *rowTable
+	// taggedPtr is used to save the row address. During the hash join build stage,
+	// chunk data is converted into row format, and each row has an unsafe.Pointer
+	// pointing to its start address. The unsafe.Pointer is converted to a taggedPtr
+	// and saved in hashTable.
+	// Generally speaking, it is unsafe or even illegal in Go to save an unsafe.Pointer
+	// into a uintptr and later convert the uintptr back to an unsafe.Pointer: once the
+	// value is held as a uintptr it has no pointer semantics and may become invalid
+	// after GC. But it is ok to do this in hash join so far because
+	// 1. the check of heapObjectsCanMove makes sure that if the object is in the heap, its address will not be changed by GC
+	// 2. the row address only points into `rowTableSegment.rawData`. `rawData` is a slice in `rowTableSegment` that is used by
+	//    multiple goroutines and whose size is expanded at runtime; this kind of slice is always allocated on the heap
+	hashTable        []taggedPtr
 	posMask          uint64
 	isRowTableEmpty  bool
 	isHashTableEmpty bool
 }

-func (st *subTable) lookup(hashValue uint64) uintptr {
-	return st.hashTable[hashValue&st.posMask]
+func (st *subTable) lookup(hashValue uint64, tagHelper *tagPtrHelper) taggedPtr {
+	ret := st.hashTable[hashValue&st.posMask]
+	hashTagValue := tagHelper.getTaggedValue(hashValue)
+	if uint64(ret)&hashTagValue != hashTagValue {
+		// if the tag value does not match, the key cannot match
+		return 0
+	}
+	return ret
 }

 func nextPowerOfTwo(value uint64) uint64 {
@@ -56,44 +73,48 @@ func newSubTable(table *rowTable) *subTable {
 		ret.isHashTableEmpty = true
 	}
 	hashTableLength := max(nextPowerOfTwo(table.validKeyCount()), uint64(32))
-	ret.hashTable = make([]uintptr, hashTableLength)
+	ret.hashTable = make([]taggedPtr, hashTableLength)
 	ret.posMask = hashTableLength - 1
 	return ret
 }

-func (st *subTable) updateHashValue(pos uint64, rowAddress unsafe.Pointer) {
-	prev := *(*unsafe.Pointer)(unsafe.Pointer(&st.hashTable[pos]))
-	*(*unsafe.Pointer)(unsafe.Pointer(&st.hashTable[pos])) = rowAddress
+func (st *subTable) updateHashValue(hashValue uint64, rowAddress unsafe.Pointer, tagHelper *tagPtrHelper) {
+	pos := hashValue & st.posMask
+	prev := st.hashTable[pos]
+	tagValue := tagHelper.getTaggedValue(hashValue | uint64(prev))
+	taggedAddress := tagHelper.toTaggedPtr(tagValue, rowAddress)
+	st.hashTable[pos] = taggedAddress
 	setNextRowAddress(rowAddress, prev)
 }

-func (st *subTable) atomicUpdateHashValue(pos uint64, rowAddress unsafe.Pointer) {
+func (st *subTable) atomicUpdateHashValue(hashValue uint64, rowAddress unsafe.Pointer, tagHelper *tagPtrHelper) {
+	pos := hashValue & st.posMask
 	for {
-		prev := atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&st.hashTable[pos])))
-		if atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(&st.hashTable[pos])), prev, rowAddress) {
+		prev := taggedPtr(atomic.LoadUintptr((*uintptr)(unsafe.Pointer(&st.hashTable[pos]))))
+		tagValue := tagHelper.getTaggedValue(hashValue | uint64(prev))
+		taggedAddress := tagHelper.toTaggedPtr(tagValue, rowAddress)
+		if atomic.CompareAndSwapUintptr((*uintptr)(unsafe.Pointer(&st.hashTable[pos])), uintptr(prev), uintptr(taggedAddress)) {
 			setNextRowAddress(rowAddress, prev)
 			break
 		}
 	}
 }

-func (st *subTable) build(startSegmentIndex int, endSegmentIndex int) {
+func (st *subTable) build(startSegmentIndex int, endSegmentIndex int, tagHelper *tagPtrHelper) {
 	if startSegmentIndex == 0 && endSegmentIndex == len(st.rowData.segments) {
 		for i := startSegmentIndex; i < endSegmentIndex; i++ {
 			for _, index := range st.rowData.segments[i].validJoinKeyPos {
 				rowAddress := st.rowData.segments[i].getRowPointer(index)
 				hashValue := st.rowData.segments[i].hashValues[index]
-				pos := hashValue & st.posMask
-				st.updateHashValue(pos, rowAddress)
+				st.updateHashValue(hashValue, rowAddress, tagHelper)
 			}
 		}
 	} else {
 		for i := startSegmentIndex; i < endSegmentIndex; i++ {
 			for _, index := range st.rowData.segments[i].validJoinKeyPos {
 				rowAddress := st.rowData.segments[i].getRowPointer(index)
 				hashValue := st.rowData.segments[i].hashValues[index]
-				pos := hashValue & st.posMask
-				st.atomicUpdateHashValue(pos, rowAddress)
+				st.atomicUpdateHashValue(hashValue, rowAddress, tagHelper)
 			}
 		}
 	}
@@ -205,7 +226,3 @@ func (jht *hashTableV2) totalRowCount() uint64 {
 	}
 	return ret
 }
-
-func (jht *hashTableV2) buildHashTableForTest(partitionIndex int, startSegmentIndex int, segmentStep int) {
-	jht.tables[partitionIndex].build(startSegmentIndex, segmentStep)
-}
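
Two details in this file are worth spelling out. First, the atomic path switches from CompareAndSwapPointer to CompareAndSwapUintptr: a tagged value is not a valid pointer, so it must never be stored or loaded through a pointer-typed slot the garbage collector might scan. Second, because each insertion tags the new bucket head with getTaggedValue(hashValue | uint64(prev)), the head's tag is the OR of the tags of every row in its chain, a tiny Bloom-filter-like summary that lets one AND in lookup rule out the whole chain. A short runnable sketch of that invariant, assuming 16 tag bits (the real width is derived per build):

    package main

    import "fmt"

    const tagMask = uint64(0xffff) << 48

    func tag(h uint64) uint64 { return h & tagMask }

    func main() {
        // Tags of three rows chained into the same bucket.
        hashes := []uint64{0x1111 << 48, 0x2222 << 48, 0x4444 << 48}
        var headTag uint64
        for _, h := range hashes {
            headTag |= tag(h) // mirrors getTaggedValue(hashValue | uint64(prev))
        }
        // A probe whose tag bit is absent from every row in the chain.
        probe := uint64(0x8000) << 48
        fmt.Println(headTag&tag(probe) == tag(probe)) // false: skip the chain entirely
        // A probe whose tag is present cannot be rejected (the key may still mismatch).
        fmt.Println(headTag&tag(0x1111<<48) == tag(0x1111<<48)) // true
    }
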
(The remaining 7 changed files, including pkg/executor/join/tagged_ptr.go and tagged_ptr_test.go, are not shown in this view.)