Skip to content

Commit

Permalink
sessionctx/stmtctx: avoid unlock of unlocked mutex panic on Statement…
Browse files Browse the repository at this point in the history
…Context
  • Loading branch information
chibiegg committed Dec 31, 2024
1 parent 42d4fae commit 37ebbf0
Showing 1 changed file with 69 additions and 46 deletions.
115 changes: 69 additions & 46 deletions pkg/sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,9 @@ const (

// GetOrStoreStmtCache gets the cached value of the given key if it exists, otherwise stores the value.
func (sc *StatementContext) GetOrStoreStmtCache(key StmtCacheKey, value any) any {
sc.stmtCache.mu.Lock()
defer sc.stmtCache.mu.Unlock()
mu := &sc.stmtCache.mu
mu.Lock()
defer mu.Unlock()
if sc.stmtCache.data == nil {
sc.stmtCache.data = make(map[StmtCacheKey]any)
}
Expand All @@ -604,8 +605,9 @@ func (sc *StatementContext) GetOrStoreStmtCache(key StmtCacheKey, value any) any

// GetOrEvaluateStmtCache gets the cached value of the given key if it exists, otherwise calculate the value.
func (sc *StatementContext) GetOrEvaluateStmtCache(key StmtCacheKey, valueEvaluator func() (any, error)) (any, error) {
sc.stmtCache.mu.Lock()
defer sc.stmtCache.mu.Unlock()
mu := &sc.stmtCache.mu
mu.Lock()
defer mu.Unlock()
if sc.stmtCache.data == nil {
sc.stmtCache.data = make(map[StmtCacheKey]any)
}
Expand All @@ -621,15 +623,17 @@ func (sc *StatementContext) GetOrEvaluateStmtCache(key StmtCacheKey, valueEvalua

// ResetInStmtCache resets the cache of given key.
func (sc *StatementContext) ResetInStmtCache(key StmtCacheKey) {
sc.stmtCache.mu.Lock()
defer sc.stmtCache.mu.Unlock()
mu := &sc.stmtCache.mu
mu.Lock()
defer mu.Unlock()
delete(sc.stmtCache.data, key)
}

// ResetStmtCache resets all cached values.
func (sc *StatementContext) ResetStmtCache() {
sc.stmtCache.mu.Lock()
defer sc.stmtCache.mu.Unlock()
mu := &sc.stmtCache.mu
mu.Lock()
defer mu.Unlock()
sc.stmtCache.data = make(map[StmtCacheKey]any)
}

Expand Down Expand Up @@ -798,120 +802,137 @@ func (sc *StatementContext) AddAffectedRows(rows uint64) {
// For compatibility with MySQL, not add the affected row cause by the foreign key trigger.
return
}
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
sc.mu.affectedRows += rows
}

// SetAffectedRows sets affected rows.
func (sc *StatementContext) SetAffectedRows(rows uint64) {
sc.mu.Lock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
sc.mu.affectedRows = rows
sc.mu.Unlock()
}

// AffectedRows gets affected rows.
func (sc *StatementContext) AffectedRows() uint64 {
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
return sc.mu.affectedRows
}

// FoundRows gets found rows.
func (sc *StatementContext) FoundRows() uint64 {
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
return sc.mu.foundRows
}

// AddFoundRows adds found rows.
func (sc *StatementContext) AddFoundRows(rows uint64) {
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
sc.mu.foundRows += rows
}

// RecordRows is used to generate info message
func (sc *StatementContext) RecordRows() uint64 {
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
return sc.mu.records
}

// AddRecordRows adds record rows.
func (sc *StatementContext) AddRecordRows(rows uint64) {
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
sc.mu.records += rows
}

// DeletedRows is used to generate info message
func (sc *StatementContext) DeletedRows() uint64 {
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
return sc.mu.deleted
}

// AddDeletedRows adds record rows.
func (sc *StatementContext) AddDeletedRows(rows uint64) {
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
sc.mu.deleted += rows
}

// UpdatedRows is used to generate info message
func (sc *StatementContext) UpdatedRows() uint64 {
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
return sc.mu.updated
}

// AddUpdatedRows adds updated rows.
func (sc *StatementContext) AddUpdatedRows(rows uint64) {
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
sc.mu.updated += rows
}

// CopiedRows is used to generate info message
func (sc *StatementContext) CopiedRows() uint64 {
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
return sc.mu.copied
}

// AddCopiedRows adds copied rows.
func (sc *StatementContext) AddCopiedRows(rows uint64) {
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
sc.mu.copied += rows
}

// TouchedRows is used to generate info message
func (sc *StatementContext) TouchedRows() uint64 {
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
return sc.mu.touched
}

// AddTouchedRows adds touched rows.
func (sc *StatementContext) AddTouchedRows(rows uint64) {
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
sc.mu.touched += rows
}

// GetMessage returns the extra message of the last executed command, if there is no message, it returns empty string
func (sc *StatementContext) GetMessage() string {
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
return sc.mu.message
}

// SetMessage sets the info message generated by some commands
func (sc *StatementContext) SetMessage(msg string) {
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
sc.mu.message = msg
}

Expand Down Expand Up @@ -995,8 +1016,9 @@ func (sc *StatementContext) AppendExtraError(warn error) {

// resetMuForRetry resets the changed states of sc.mu during execution.
func (sc *StatementContext) resetMuForRetry() {
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
sc.mu.affectedRows = 0
sc.mu.foundRows = 0
sc.mu.records = 0
Expand Down Expand Up @@ -1025,8 +1047,9 @@ func (sc *StatementContext) ResetForRetry() {
// GetExecDetails gets the execution details for the statement.
func (sc *StatementContext) GetExecDetails() execdetails.ExecDetails {
var details execdetails.ExecDetails
sc.mu.Lock()
defer sc.mu.Unlock()
mu := &sc.mu
mu.Lock()
defer mu.Unlock()
details = sc.SyncExecDetails.GetExecDetails()
details.LockKeysDuration = time.Duration(atomic.LoadInt64(&sc.LockKeysDuration))
return details
Expand Down

0 comments on commit 37ebbf0

Please sign in to comment.