diff --git a/session/session.go b/session/session.go index 96d2851653a98..ca9d361893faa 100644 --- a/session/session.go +++ b/session/session.go @@ -1219,22 +1219,32 @@ func (s *session) isTxnRetryableError(err error) bool { return kv.IsTxnRetryableError(err) || domain.ErrInfoSchemaChanged.Equal(err) } +func isEndTxnStmt(stmt ast.StmtNode, vars *variable.SessionVars) (bool, error) { + switch n := stmt.(type) { + case *ast.RollbackStmt, *ast.CommitStmt: + return true, nil + case *ast.ExecuteStmt: + ps, err := plannercore.GetPreparedStmt(n, vars) + if err != nil { + return false, err + } + return isEndTxnStmt(ps.PreparedAst.Stmt, vars) + } + return false, nil +} + func (s *session) checkTxnAborted(stmt sqlexec.Statement) error { - var err error - if atomic.LoadUint32(&s.GetSessionVars().TxnCtx.LockExpire) > 0 { - err = kv.ErrLockExpire - } else { + if atomic.LoadUint32(&s.GetSessionVars().TxnCtx.LockExpire) == 0 { return nil } // If the transaction is aborted, the following statements do not need to execute, except `commit` and `rollback`, // because they are used to finish the aborted transaction. - if _, ok := stmt.(*executor.ExecStmt).StmtNode.(*ast.CommitStmt); ok { - return nil - } - if _, ok := stmt.(*executor.ExecStmt).StmtNode.(*ast.RollbackStmt); ok { + if ok, err := isEndTxnStmt(stmt.(*executor.ExecStmt).StmtNode, s.sessionVars); err == nil && ok { return nil + } else if err != nil { + return err } - return err + return kv.ErrLockExpire } func (s *session) retry(ctx context.Context, maxCnt uint) (err error) { diff --git a/tests/realtikvtest/pessimistictest/BUILD.bazel b/tests/realtikvtest/pessimistictest/BUILD.bazel index d728cd27e686a..86dfdb4cd2dd3 100644 --- a/tests/realtikvtest/pessimistictest/BUILD.bazel +++ b/tests/realtikvtest/pessimistictest/BUILD.bazel @@ -35,6 +35,7 @@ go_test( "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", "@com_github_stretchr_testify//require", + "@com_github_tikv_client_go_v2//config", "@com_github_tikv_client_go_v2//oracle", "@com_github_tikv_client_go_v2//testutils", "@com_github_tikv_client_go_v2//tikv", diff --git a/tests/realtikvtest/pessimistictest/pessimistic_test.go b/tests/realtikvtest/pessimistictest/pessimistic_test.go index 0db67fefc9fff..e322a3d803286 100644 --- a/tests/realtikvtest/pessimistictest/pessimistic_test.go +++ b/tests/realtikvtest/pessimistictest/pessimistic_test.go @@ -50,6 +50,7 @@ import ( "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/deadlockhistory" "github.com/stretchr/testify/require" + tikvcfg "github.com/tikv/client-go/v2/config" "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/testutils" "github.com/tikv/client-go/v2/tikv" @@ -3793,3 +3794,37 @@ func TestIssue42937(t *testing.T) { "5 11", )) } + +func TestEndTxnOnLockExpire(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t(a int, b int)") + tk.MustExec("prepare ps_commit from 'commit'") + tk.MustExec("prepare ps_rollback from 'rollback'") + + defer setLockTTL(300).restore() + defer tikvcfg.UpdateGlobal(func(conf *tikvcfg.Config) { + conf.MaxTxnTTL = 500 + })() + + for _, tt := range []struct { + name string + endTxnSQL string + }{ + {"CommitTxt", "commit"}, + {"CommitBin", "execute ps_commit"}, + {"RollbackTxt", "rollback"}, + {"RollbackBin", "execute ps_rollback"}, + } { + t.Run(tt.name, func(t *testing.T) { + tk.Exec("delete from t") + tk.Exec("insert into t values (1, 1)") + tk.Exec("begin pessimistic") + tk.Exec("update t set b = 10 where a = 1") + time.Sleep(time.Second) + tk.MustContainErrMsg("select * from t", "TTL manager has timed out") + tk.MustExec(tt.endTxnSQL) + }) + } +}