diff --git a/pkg/datasource/sql/conn.go b/pkg/datasource/sql/conn.go index 99d4453f5..9e441ccfc 100644 --- a/pkg/datasource/sql/conn.go +++ b/pkg/datasource/sql/conn.go @@ -31,7 +31,6 @@ import ( // Conn is assumed to be stateful. type Conn struct { - txType types.TransactionType res *DBResource txCtx *types.TransactionContext targetConn driver.Conn @@ -47,8 +46,8 @@ func (c *Conn) ResetSession(ctx context.Context) error { return driver.ErrSkip } - c.txType = types.Local - c.txCtx = nil + c.autoCommit = true + c.txCtx = types.NewTxCtx() return conn.ResetSession(ctx) } @@ -221,26 +220,29 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam return c.Query(query, values) } - executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransType, query) - if err != nil { - return nil, err - } + ret, err := c.createNewTxOnExecIfNeed(func() (types.ExecResult, error) { + executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransType, query) + if err != nil { + return nil, err + } - execCtx := &types.ExecContext{ - TxCtx: c.txCtx, - Query: query, - NamedValues: args, - } + execCtx := &types.ExecContext{ + TxCtx: c.txCtx, + Query: query, + NamedValues: args, + } - ret, err := executor.ExecWithNamedValue(ctx, execCtx, - func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) { - ret, err := conn.QueryContext(ctx, query, args) - if err != nil { - return nil, err - } + return executor.ExecWithNamedValue(ctx, execCtx, + func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) { + ret, err := conn.QueryContext(ctx, query, args) + if err != nil { + return nil, err + } + + return types.NewResult(types.WithRows(ret)), nil + }) + }) - return types.NewResult(types.WithRows(ret)), nil - }) if err != nil { return nil, err } @@ -252,6 +254,8 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam // // Deprecated: Drivers should implement ConnBeginTx instead (or additionally). func (c *Conn) Begin() (driver.Tx, error) { + c.autoCommit = false + tx, err := c.targetConn.Begin() if err != nil { return nil, err @@ -271,8 +275,11 @@ func (c *Conn) Begin() (driver.Tx, error) { } // BeginTx Open a transaction and judge whether the current transaction needs to open a -// global transaction according to ctx. If so, it needs to be included in the transaction management of seata +// +// global transaction according to ctx. If so, it needs to be included in the transaction management of seata func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + c.autoCommit = false + if conn, ok := c.targetConn.(driver.ConnBeginTx); ok { tx, err := conn.BeginTx(ctx, opts) if err != nil { @@ -309,7 +316,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e // Drivers must ensure all network calls made by Close // do not block indefinitely (e.g. apply a timeout). func (c *Conn) Close() error { - c.txCtx = nil + c.txCtx = types.NewTxCtx() return c.targetConn.Close() } diff --git a/pkg/datasource/sql/conn_at.go b/pkg/datasource/sql/conn_at.go index e511c2ac0..3caafad41 100644 --- a/pkg/datasource/sql/conn_at.go +++ b/pkg/datasource/sql/conn_at.go @@ -25,25 +25,38 @@ import ( "github.com/seata/seata-go/pkg/tm" ) +// ATConn Database connection proxy object under XA transaction model +// Conn is assumed to be stateful. type ATConn struct { *Conn } func (c *ATConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - if c.createTxCtxIfAbsent(ctx) { + if c.createOnceTxContext(ctx) { defer func() { - c.txCtx = nil + c.txCtx = types.NewTxCtx() }() } return c.Conn.PrepareContext(ctx, query) } +// QueryContext +func (c *ATConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + if c.createOnceTxContext(ctx) { + defer func() { + c.txCtx = types.NewTxCtx() + }() + } + + return c.Conn.QueryContext(ctx, query, args) +} + // ExecContext func (c *ATConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - if c.createTxCtxIfAbsent(ctx) { + if c.createOnceTxContext(ctx) { defer func() { - c.txCtx = nil + c.txCtx = types.NewTxCtx() }() } @@ -52,33 +65,33 @@ func (c *ATConn) ExecContext(ctx context.Context, query string, args []driver.Na // BeginTx func (c *ATConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + c.autoCommit = false + c.txCtx = types.NewTxCtx() c.txCtx.DBType = c.res.dbType c.txCtx.TxOpt = opts if IsGlobalTx(ctx) { c.txCtx.XaID = tm.GetXID(ctx) - c.txCtx.TransType = c.txType + c.txCtx.TransType = types.ATMode + } + + tx, err := c.Conn.BeginTx(ctx, opts) + if err != nil { + return nil, err } - return c.Conn.BeginTx(ctx, opts) + return &ATTx{tx: tx.(*Tx)}, nil } -func (c *ATConn) createTxCtxIfAbsent(ctx context.Context) bool { - var onceTx bool +func (c *ATConn) createOnceTxContext(ctx context.Context) bool { + onceTx := IsGlobalTx(ctx) && c.autoCommit - if IsGlobalTx(ctx) && c.txCtx == nil { + if onceTx { c.txCtx = types.NewTxCtx() c.txCtx.DBType = c.res.dbType c.txCtx.XaID = tm.GetXID(ctx) c.txCtx.TransType = types.ATMode - c.autoCommit = true - onceTx = true - } - - if c.txCtx == nil { - c.txCtx = types.NewTxCtx() - onceTx = true } return onceTx diff --git a/pkg/datasource/sql/conn_at_test.go b/pkg/datasource/sql/conn_at_test.go index bb03d74ee..c962ffb28 100644 --- a/pkg/datasource/sql/conn_at_test.go +++ b/pkg/datasource/sql/conn_at_test.go @@ -20,6 +20,7 @@ package sql import ( "context" "database/sql" + "database/sql/driver" "sync/atomic" "testing" @@ -32,9 +33,8 @@ import ( "github.com/stretchr/testify/assert" ) -func TestATConn_ExecContext(t *testing.T) { +func initAtConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQLInterceptor, *mockTxHook) { ctrl := gomock.NewController(t) - defer ctrl.Finish() mockMgr := initMockResourceManager(t, ctrl) _ = mockMgr @@ -44,9 +44,7 @@ func TestATConn_ExecContext(t *testing.T) { t.Fatal(err) } - defer db.Close() - - _ = initMockAtConnector(t, ctrl, db, func(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector { + _ = initMockAtConnector(t, ctrl, db, func(t *testing.T, ctrl *gomock.Controller) driver.Connector { mockTx := mock.NewMockTestDriverTx(ctrl) mockTx.EXPECT().Commit().AnyTimes().Return(nil) mockTx.EXPECT().Rollback().AnyTimes().Return(nil) @@ -67,20 +65,31 @@ func TestATConn_ExecContext(t *testing.T) { exec.CleanCommonHook() CleanTxHooks() - exec.RegisCommonHook(mi) + exec.RegisterCommonHook(mi) RegisterTxHook(ti) + return ctrl, db, mi, ti +} + +func TestATConn_ExecContext(t *testing.T) { + ctrl, db, mi, ti := initAtConnTestResource(t) + defer func() { + ctrl.Finish() + db.Close() + CleanTxHooks() + }() + t.Run("have xid", func(t *testing.T) { ctx := tm.InitSeataContext(context.Background()) tm.SetXID(ctx, uuid.New().String()) t.Logf("set xid=%s", tm.GetXID(ctx)) - before := func(_ context.Context, execCtx *types.ExecContext) { + beforeHook := func(_ context.Context, execCtx *types.ExecContext) { t.Logf("on exec xid=%s", execCtx.TxCtx.XaID) assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID) assert.Equal(t, types.ATMode, execCtx.TxCtx.TransType) } - mi.before = before + mi.before = beforeHook var comitCnt int32 beforeCommit := func(tx *Tx) { @@ -125,3 +134,92 @@ func TestATConn_ExecContext(t *testing.T) { assert.Equal(t, int32(0), atomic.LoadInt32(&comitCnt)) }) } + +func TestATConn_BeginTx(t *testing.T) { + ctrl, db, mi, ti := initAtConnTestResource(t) + defer func() { + ctrl.Finish() + db.Close() + CleanTxHooks() + }() + + t.Run("tx-local", func(t *testing.T) { + tx, err := db.Begin() + assert.NoError(t, err) + + mi.before = func(_ context.Context, execCtx *types.ExecContext) { + assert.Equal(t, "", execCtx.TxCtx.XaID) + assert.Equal(t, types.Local, execCtx.TxCtx.TransType) + } + + var comitCnt int32 + ti.beforeCommit = func(tx *Tx) { + atomic.AddInt32(&comitCnt, 1) + } + + _, err = tx.ExecContext(context.Background(), "SELECT * FROM user") + assert.NoError(t, err) + + _, err = tx.ExecContext(tm.InitSeataContext(context.Background()), "SELECT * FROM user") + assert.NoError(t, err) + + err = tx.Commit() + assert.NoError(t, err) + + assert.Equal(t, int32(1), atomic.LoadInt32(&comitCnt)) + }) + + t.Run("tx-local-context", func(t *testing.T) { + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) + assert.NoError(t, err) + + mi.before = func(_ context.Context, execCtx *types.ExecContext) { + assert.Equal(t, "", execCtx.TxCtx.XaID) + assert.Equal(t, types.Local, execCtx.TxCtx.TransType) + } + + var comitCnt int32 + ti.beforeCommit = func(tx *Tx) { + atomic.AddInt32(&comitCnt, 1) + } + + _, err = tx.ExecContext(context.Background(), "SELECT * FROM user") + assert.NoError(t, err) + + _, err = tx.ExecContext(tm.InitSeataContext(context.Background()), "SELECT * FROM user") + assert.NoError(t, err) + + err = tx.Commit() + assert.NoError(t, err) + + assert.Equal(t, int32(1), atomic.LoadInt32(&comitCnt)) + }) + + t.Run("tx-at-context", func(t *testing.T) { + ctx := tm.InitSeataContext(context.Background()) + tm.SetXID(ctx, uuid.NewString()) + tx, err := db.BeginTx(ctx, &sql.TxOptions{}) + assert.NoError(t, err) + + mi.before = func(_ context.Context, execCtx *types.ExecContext) { + assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID) + assert.Equal(t, types.ATMode, execCtx.TxCtx.TransType) + } + + var comitCnt int32 + ti.beforeCommit = func(tx *Tx) { + atomic.AddInt32(&comitCnt, 1) + } + + _, err = tx.ExecContext(context.Background(), "SELECT * FROM user") + assert.NoError(t, err) + + _, err = tx.ExecContext(context.Background(), "SELECT * FROM user") + assert.NoError(t, err) + + err = tx.Commit() + assert.NoError(t, err) + + assert.Equal(t, int32(1), atomic.LoadInt32(&comitCnt)) + }) +} diff --git a/pkg/datasource/sql/conn_test.go b/pkg/datasource/sql/conn_test.go new file mode 100644 index 000000000..9ba23d448 --- /dev/null +++ b/pkg/datasource/sql/conn_test.go @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sql + +import ( + "reflect" + "testing" + + "github.com/seata/seata-go/pkg/datasource/sql/exec" + "github.com/seata/seata-go/pkg/datasource/sql/exec/xa" + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/util/reflectx" + "github.com/stretchr/testify/assert" +) + +func TestConn_BuildATExecutor(t *testing.T) { + executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.ATMode, "SELECT * FROM user") + + assert.NoError(t, err) + _, ok := executor.(*exec.BaseExecutor) + assert.True(t, ok, "need base executor") +} + +func TestConn_BuildXAExecutor(t *testing.T) { + executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.XAMode, "SELECT * FROM user") + + assert.NoError(t, err) + val, ok := executor.(*exec.BaseExecutor) + assert.True(t, ok, "need base executor") + + v := reflect.ValueOf(val) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + field := v.FieldByName("ex") + + fieldVal := reflectx.GetUnexportedField(field) + + _, ok = fieldVal.(*xa.XAExecutor) + assert.True(t, ok, "need xa executor") +} diff --git a/pkg/datasource/sql/conn_xa.go b/pkg/datasource/sql/conn_xa.go index 7078e4c8a..1924bd654 100644 --- a/pkg/datasource/sql/conn_xa.go +++ b/pkg/datasource/sql/conn_xa.go @@ -25,15 +25,28 @@ import ( "github.com/seata/seata-go/pkg/tm" ) +// XAConn Database connection proxy object under XA transaction model +// Conn is assumed to be stateful. type XAConn struct { *Conn } +// QueryContext +func (c *XAConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + if c.createOnceTxContext(ctx) { + defer func() { + c.txCtx = types.NewTxCtx() + }() + } + + return c.Conn.QueryContext(ctx, query, args) +} + // PrepareContext func (c *XAConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - if c.createTxCtxIfAbsent(ctx) { + if c.createOnceTxContext(ctx) { defer func() { - c.txCtx = nil + c.txCtx = types.NewTxCtx() }() } @@ -42,9 +55,9 @@ func (c *XAConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, // ExecContext func (c *XAConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - if c.createTxCtxIfAbsent(ctx) { + if c.createOnceTxContext(ctx) { defer func() { - c.txCtx = nil + c.txCtx = types.NewTxCtx() }() } @@ -53,6 +66,8 @@ func (c *XAConn) ExecContext(ctx context.Context, query string, args []driver.Na // BeginTx func (c *XAConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + c.autoCommit = false + c.txCtx = types.NewTxCtx() c.txCtx.DBType = c.res.dbType c.txCtx.TxOpt = opts @@ -62,24 +77,22 @@ func (c *XAConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, c.txCtx.XaID = tm.GetXID(ctx) } - return c.Conn.BeginTx(ctx, opts) + tx, err := c.Conn.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + + return &XATx{tx: tx.(*Tx)}, nil } -func (c *XAConn) createTxCtxIfAbsent(ctx context.Context) bool { - var onceTx bool +func (c *XAConn) createOnceTxContext(ctx context.Context) bool { + onceTx := IsGlobalTx(ctx) && c.autoCommit - if IsGlobalTx(ctx) && c.txCtx == nil { + if onceTx { c.txCtx = types.NewTxCtx() c.txCtx.DBType = c.res.dbType c.txCtx.XaID = tm.GetXID(ctx) c.txCtx.TransType = types.XAMode - c.autoCommit = true - onceTx = true - } - - if c.txCtx == nil { - c.txCtx = types.NewTxCtx() - onceTx = true } return onceTx diff --git a/pkg/datasource/sql/conn_xa_test.go b/pkg/datasource/sql/conn_xa_test.go index 53480e711..0140d0b1f 100644 --- a/pkg/datasource/sql/conn_xa_test.go +++ b/pkg/datasource/sql/conn_xa_test.go @@ -84,9 +84,8 @@ func baseMoclConn(mockConn *mock.MockTestDriverConn) { mockConn.EXPECT().Close().AnyTimes().Return(nil) } -func TestXAConn_ExecContext(t *testing.T) { +func initXAConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQLInterceptor, *mockTxHook) { ctrl := gomock.NewController(t) - defer ctrl.Finish() mockMgr := initMockResourceManager(t, ctrl) _ = mockMgr @@ -96,9 +95,7 @@ func TestXAConn_ExecContext(t *testing.T) { t.Fatal(err) } - defer db.Close() - - _ = initMockXaConnector(t, ctrl, db, func(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector { + _ = initMockXaConnector(t, ctrl, db, func(t *testing.T, ctrl *gomock.Controller) driver.Connector { mockTx := mock.NewMockTestDriverTx(ctrl) mockTx.EXPECT().Commit().AnyTimes().Return(nil) mockTx.EXPECT().Rollback().AnyTimes().Return(nil) @@ -118,9 +115,21 @@ func TestXAConn_ExecContext(t *testing.T) { exec.CleanCommonHook() CleanTxHooks() - exec.RegisCommonHook(mi) + exec.RegisterCommonHook(mi) RegisterTxHook(ti) + return ctrl, db, mi, ti +} + +func TestXAConn_ExecContext(t *testing.T) { + + ctrl, db, mi, ti := initXAConnTestResource(t) + defer func() { + ctrl.Finish() + db.Close() + CleanTxHooks() + }() + t.Run("have xid", func(t *testing.T) { ctx := tm.InitSeataContext(context.Background()) tm.SetXID(ctx, uuid.New().String()) @@ -178,3 +187,93 @@ func TestXAConn_ExecContext(t *testing.T) { assert.Equal(t, int32(0), atomic.LoadInt32(&comitCnt)) }) } + +func TestXAConn_BeginTx(t *testing.T) { + ctrl, db, mi, ti := initXAConnTestResource(t) + defer func() { + CleanTxHooks() + db.Close() + ctrl.Finish() + }() + + t.Run("tx-local", func(t *testing.T) { + tx, err := db.Begin() + assert.NoError(t, err) + + mi.before = func(_ context.Context, execCtx *types.ExecContext) { + assert.Equal(t, "", execCtx.TxCtx.XaID) + assert.Equal(t, types.Local, execCtx.TxCtx.TransType) + } + + var comitCnt int32 + ti.beforeCommit = func(tx *Tx) { + atomic.AddInt32(&comitCnt, 1) + } + + _, err = tx.ExecContext(context.Background(), "SELECT * FROM user") + assert.NoError(t, err) + + _, err = tx.ExecContext(tm.InitSeataContext(context.Background()), "SELECT * FROM user") + assert.NoError(t, err) + + err = tx.Commit() + assert.NoError(t, err) + + assert.Equal(t, int32(1), atomic.LoadInt32(&comitCnt)) + }) + + t.Run("tx-local-context", func(t *testing.T) { + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) + assert.NoError(t, err) + + mi.before = func(_ context.Context, execCtx *types.ExecContext) { + assert.Equal(t, "", execCtx.TxCtx.XaID) + assert.Equal(t, types.Local, execCtx.TxCtx.TransType) + } + + var comitCnt int32 + ti.beforeCommit = func(tx *Tx) { + atomic.AddInt32(&comitCnt, 1) + } + + _, err = tx.ExecContext(context.Background(), "SELECT * FROM user") + assert.NoError(t, err) + + _, err = tx.ExecContext(tm.InitSeataContext(context.Background()), "SELECT * FROM user") + assert.NoError(t, err) + + err = tx.Commit() + assert.NoError(t, err) + + assert.Equal(t, int32(1), atomic.LoadInt32(&comitCnt)) + }) + + t.Run("tx-xa-context", func(t *testing.T) { + ctx := tm.InitSeataContext(context.Background()) + tm.SetXID(ctx, uuid.NewString()) + tx, err := db.BeginTx(ctx, &sql.TxOptions{}) + assert.NoError(t, err) + + mi.before = func(_ context.Context, execCtx *types.ExecContext) { + assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID) + assert.Equal(t, types.XAMode, execCtx.TxCtx.TransType) + } + + var comitCnt int32 + ti.beforeCommit = func(tx *Tx) { + atomic.AddInt32(&comitCnt, 1) + } + + _, err = tx.ExecContext(context.Background(), "SELECT * FROM user") + assert.NoError(t, err) + + _, err = tx.ExecContext(context.Background(), "SELECT * FROM user") + assert.NoError(t, err) + + err = tx.Commit() + assert.NoError(t, err) + + assert.Equal(t, int32(1), atomic.LoadInt32(&comitCnt)) + }) + +} diff --git a/pkg/datasource/sql/connector.go b/pkg/datasource/sql/connector.go index 146f50eb5..fa1588a47 100644 --- a/pkg/datasource/sql/connector.go +++ b/pkg/datasource/sql/connector.go @@ -113,7 +113,12 @@ func (c *seataConnector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } - return &Conn{txType: types.Local, targetConn: conn, res: c.res}, nil + return &Conn{ + targetConn: conn, + res: c.res, + txCtx: types.NewTxCtx(), + autoCommit: true, + }, nil } // Driver returns the underlying Driver of the Connector, diff --git a/pkg/datasource/sql/connector_test.go b/pkg/datasource/sql/connector_test.go index 2f6bc4bea..35a19a837 100644 --- a/pkg/datasource/sql/connector_test.go +++ b/pkg/datasource/sql/connector_test.go @@ -31,9 +31,9 @@ import ( "github.com/stretchr/testify/assert" ) -type initConnectorFunc func(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector +type initConnectorFunc func(t *testing.T, ctrl *gomock.Controller) driver.Connector -func initMockConnector(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector { +func initMockConnector(t *testing.T, ctrl *gomock.Controller) driver.Connector { mockConn := mock.NewMockTestDriverConn(ctrl) connector := mock.NewMockTestDriverConnector(ctrl) @@ -82,7 +82,7 @@ func Test_seataATConnector_Connect(t *testing.T) { atConn, ok := conn.(*ATConn) assert.True(t, ok, "need return seata at connection") - assert.True(t, atConn.txType == types.Local, "init need local tx") + assert.True(t, atConn.txCtx.TransType == types.Local, "init need local tx") } func initMockXaConnector(t *testing.T, ctrl *gomock.Controller, db *sql.DB, f initConnectorFunc) driver.Connector { @@ -126,5 +126,5 @@ func Test_seataXAConnector_Connect(t *testing.T) { xaConn, ok := conn.(*XAConn) assert.True(t, ok, "need return seata xa connection") - assert.True(t, xaConn.txType == types.Local, "init need local tx") + assert.True(t, xaConn.txCtx.TransType == types.Local, "init need local tx") } diff --git a/pkg/datasource/sql/driver.go b/pkg/datasource/sql/driver.go index 5b89f850d..fdecca077 100644 --- a/pkg/datasource/sql/driver.go +++ b/pkg/datasource/sql/driver.go @@ -22,7 +22,6 @@ import ( "database/sql" "database/sql/driver" "fmt" - "reflect" "strings" "github.com/go-sql-driver/mysql" @@ -30,7 +29,6 @@ import ( "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/protocol/branch" "github.com/seata/seata-go/pkg/util/log" - "github.com/seata/seata-go/pkg/util/reflectx" ) const ( @@ -103,24 +101,11 @@ func (d *seataDriver) Open(name string) (driver.Conn, error) { return nil, err } - v := reflect.ValueOf(conn) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } - - field := v.FieldByName("connector") - proxy, err := d.OpenConnector(name) - if err != nil { - log.Errorf("open connector: %w", err) - return nil, err - } - - reflectx.SetUnexportedField(field, proxy) return conn, nil } func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) { - c = &dsnConnector{dsn: name, driver: d.target} + c = &dsnConnector{dsn: name, driver: d} if driverCtx, ok := d.target.(driver.DriverContext); ok { c, err = driverCtx.OpenConnector(name) if err != nil { diff --git a/pkg/datasource/sql/driver_test.go b/pkg/datasource/sql/driver_test.go index 9077dd315..5ded5e352 100644 --- a/pkg/datasource/sql/driver_test.go +++ b/pkg/datasource/sql/driver_test.go @@ -18,7 +18,9 @@ package sql import ( + "context" "database/sql" + "database/sql/driver" "reflect" "testing" @@ -40,6 +42,73 @@ func initMockResourceManager(t *testing.T, ctrl *gomock.Controller) *mock.MockDa return mockResourceMgr } +func Test_seataATDriver_Open(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockMgr := initMockResourceManager(t, ctrl) + _ = mockMgr + + db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") + if err != nil { + t.Fatal(err) + } + + defer db.Close() + + _ = initMockAtConnector(t, ctrl, db, func(t *testing.T, ctrl *gomock.Controller) driver.Connector { + + v := reflect.ValueOf(db) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + field := v.FieldByName("connector") + fieldVal := reflectx.GetUnexportedField(field) + + driverVal, ok := fieldVal.(driver.Connector).Driver().(*seataATDriver) + assert.True(t, ok, "need seata at driver") + + vv := reflect.ValueOf(driverVal) + if vv.Kind() == reflect.Ptr { + vv = vv.Elem() + } + field = vv.FieldByName("target") + + mockDriver := mock.NewMockTestDriver(ctrl) + mockDriver.EXPECT().Open(gomock.Any()).Return(mock.NewMockTestDriverConn(ctrl), nil) + + reflectx.SetUnexportedField(field, mockDriver) + + connector := &dsnConnector{ + driver: driverVal, + } + return connector + }) + + conn, err := db.Conn(context.Background()) + assert.NoError(t, err) + + v := reflect.ValueOf(conn) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + field := v.FieldByName("dc") + fieldVal := reflectx.GetUnexportedField(field) + + vv := reflect.ValueOf(fieldVal) + if vv.Kind() == reflect.Ptr { + vv = vv.Elem() + } + + field = vv.FieldByName("ci") + fieldVal = reflectx.GetUnexportedField(field) + + _, ok := fieldVal.(*ATConn) + assert.True(t, ok, "need return seata at connection") +} + func Test_seataATDriver_OpenConnector(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() diff --git a/pkg/datasource/sql/exec/at/default.go b/pkg/datasource/sql/exec/at/default.go new file mode 100644 index 000000000..b99ce6517 --- /dev/null +++ b/pkg/datasource/sql/exec/at/default.go @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package at + +import ( + "github.com/seata/seata-go/pkg/datasource/sql/exec" + "github.com/seata/seata-go/pkg/datasource/sql/types" +) + +func init() { + exec.RegisterXAExecutor(types.DBTypeMySQL, func() exec.SQLExecutor { + return &ATExecutor{} + }) +} diff --git a/pkg/datasource/sql/exec/at/executor_at.go b/pkg/datasource/sql/exec/at/executor_at.go new file mode 100644 index 000000000..3c04141fd --- /dev/null +++ b/pkg/datasource/sql/exec/at/executor_at.go @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package at + +import ( + "context" + + "github.com/seata/seata-go/pkg/datasource/sql/exec" + "github.com/seata/seata-go/pkg/datasource/sql/types" +) + +type ATExecutor struct { + is []exec.SQLHook + ex exec.SQLExecutor +} + +// Interceptors +func (e *ATExecutor) Interceptors(interceptors []exec.SQLHook) { + e.is = interceptors +} + +// ExecWithNamedValue +func (e *ATExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { + for i := range e.is { + e.is[i].Before(ctx, execCtx) + } + + defer func() { + for i := range e.is { + e.is[i].After(ctx, execCtx) + } + }() + + if e.ex != nil { + return e.ex.ExecWithNamedValue(ctx, execCtx, f) + } + + return f(ctx, execCtx.Query, execCtx.NamedValues) +} + +// ExecWithValue +func (e *ATExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithValue) (types.ExecResult, error) { + for i := range e.is { + e.is[i].Before(ctx, execCtx) + } + + defer func() { + for i := range e.is { + e.is[i].After(ctx, execCtx) + } + }() + + if e.ex != nil { + return e.ex.ExecWithValue(ctx, execCtx, f) + } + + return f(ctx, execCtx.Query, execCtx.Values) +} diff --git a/pkg/datasource/sql/exec/executor.go b/pkg/datasource/sql/exec/executor.go index 781f03863..9392b6e7d 100644 --- a/pkg/datasource/sql/exec/executor.go +++ b/pkg/datasource/sql/exec/executor.go @@ -22,7 +22,6 @@ import ( "database/sql/driver" "github.com/seata/seata-go/pkg/datasource/sql/parser" - "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/datasource/sql/undo" "github.com/seata/seata-go/pkg/datasource/sql/undo/builder" @@ -36,20 +35,31 @@ func init() { } // executorSolts -var executorSolts = make(map[types.DBType]map[types.ExecutorType]func() SQLExecutor) +var ( + executorSoltsAT = make(map[types.DBType]map[types.ExecutorType]func() SQLExecutor) + executorSoltsXA = make(map[types.DBType]func() SQLExecutor) +) -func RegisterExecutor(dt types.DBType, et types.ExecutorType, builder func() SQLExecutor) { - if _, ok := executorSolts[dt]; !ok { - executorSolts[dt] = make(map[types.ExecutorType]func() SQLExecutor) +// RegisterATExecutor +func RegisterATExecutor(dt types.DBType, et types.ExecutorType, builder func() SQLExecutor) { + if _, ok := executorSoltsAT[dt]; !ok { + executorSoltsAT[dt] = make(map[types.ExecutorType]func() SQLExecutor) } - val := executorSolts[dt] + val := executorSoltsAT[dt] val[et] = func() SQLExecutor { return &BaseExecutor{ex: builder()} } } +// RegisterXAExecutor +func RegisterXAExecutor(dt types.DBType, builder func() SQLExecutor) { + executorSoltsXA[dt] = func() SQLExecutor { + return &BaseExecutor{ex: builder()} + } +} + type ( CallbackWithNamedValue func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) @@ -57,7 +67,7 @@ type ( SQLExecutor interface { // Interceptors - interceptors(interceptors []SQLHook) + Interceptors(interceptors []SQLHook) // Exec ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) // Exec @@ -71,8 +81,8 @@ func BuildExecutor(dbType types.DBType, txType types.TransactionType, query stri hooks := make([]SQLHook, 0, 4) hooks = append(hooks, commonHook...) - e := &BaseExecutor{} - e.interceptors(hooks) + e := executorSoltsXA[dbType]() + e.Interceptors(hooks) return e, nil } @@ -85,11 +95,12 @@ func BuildExecutor(dbType types.DBType, txType types.TransactionType, query stri hooks = append(hooks, commonHook...) hooks = append(hooks, hookSolts[parseCtx.SQLType]...) - factories, ok := executorSolts[dbType] + factories, ok := executorSoltsAT[dbType] + if !ok { log.Debugf("%s not found executor factories, return default Executor", dbType.String()) e := &BaseExecutor{} - e.interceptors(hooks) + e.Interceptors(hooks) return e, nil } @@ -98,12 +109,12 @@ func BuildExecutor(dbType types.DBType, txType types.TransactionType, query stri log.Debugf("%s not found executor for %s, return default Executor", dbType.String(), parseCtx.ExecutorType) e := &BaseExecutor{} - e.interceptors(hooks) + e.Interceptors(hooks) return e, nil } executor := supplier() - executor.interceptors(hooks) + executor.Interceptors(hooks) return executor, nil } @@ -113,7 +124,7 @@ type BaseExecutor struct { } // Interceptors -func (e *BaseExecutor) interceptors(interceptors []SQLHook) { +func (e *BaseExecutor) Interceptors(interceptors []SQLHook) { e.is = interceptors } diff --git a/pkg/datasource/sql/exec/hook.go b/pkg/datasource/sql/exec/hook.go index 4f5c3e420..a9c47a486 100644 --- a/pkg/datasource/sql/exec/hook.go +++ b/pkg/datasource/sql/exec/hook.go @@ -30,7 +30,7 @@ var ( ) // RegisCommonHook not goroutine safe -func RegisCommonHook(hook SQLHook) { +func RegisterCommonHook(hook SQLHook) { commonHook = append(commonHook, hook) } @@ -38,8 +38,8 @@ func CleanCommonHook() { commonHook = make([]SQLHook, 0, 4) } -// RegisHook not goroutine safe -func RegisHook(hook SQLHook) { +// RegisterHook not goroutine safe +func RegisterHook(hook SQLHook) { _, ok := hookSolts[hook.Type()] if !ok { diff --git a/pkg/datasource/sql/exec/xa/default.go b/pkg/datasource/sql/exec/xa/default.go new file mode 100644 index 000000000..6472b74cf --- /dev/null +++ b/pkg/datasource/sql/exec/xa/default.go @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package xa + +import ( + "github.com/seata/seata-go/pkg/datasource/sql/exec" + "github.com/seata/seata-go/pkg/datasource/sql/types" +) + + +func init() { + exec.RegisterXAExecutor(types.DBTypeMySQL, func() exec.SQLExecutor { + return &XAExecutor{} + }) +} diff --git a/pkg/datasource/sql/exec/xa/executor_xa.go b/pkg/datasource/sql/exec/xa/executor_xa.go new file mode 100644 index 000000000..ca6c56dbf --- /dev/null +++ b/pkg/datasource/sql/exec/xa/executor_xa.go @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package xa + +import ( + "context" + + "github.com/seata/seata-go/pkg/datasource/sql/exec" + "github.com/seata/seata-go/pkg/datasource/sql/types" +) + +type XAExecutor struct { + is []exec.SQLHook + ex exec.SQLExecutor +} + +// Interceptors +func (e *XAExecutor) Interceptors(interceptors []exec.SQLHook) { + e.is = interceptors +} + +// ExecWithNamedValue +func (e *XAExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { + for i := range e.is { + e.is[i].Before(ctx, execCtx) + } + + defer func() { + for i := range e.is { + e.is[i].After(ctx, execCtx) + } + }() + + if e.ex != nil { + return e.ex.ExecWithNamedValue(ctx, execCtx, f) + } + + return f(ctx, execCtx.Query, execCtx.NamedValues) +} + +// ExecWithValue +func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithValue) (types.ExecResult, error) { + for i := range e.is { + e.is[i].Before(ctx, execCtx) + } + + defer func() { + for i := range e.is { + e.is[i].After(ctx, execCtx) + } + }() + + if e.ex != nil { + return e.ex.ExecWithValue(ctx, execCtx, f) + } + + return f(ctx, execCtx.Query, execCtx.Values) +} diff --git a/pkg/datasource/sql/hook/logger_hook.go b/pkg/datasource/sql/hook/logger_hook.go index 7d9ee8f82..2cc30d317 100644 --- a/pkg/datasource/sql/hook/logger_hook.go +++ b/pkg/datasource/sql/hook/logger_hook.go @@ -27,7 +27,7 @@ import ( ) func init() { - exec.RegisCommonHook(&loggerSQLHook{}) + exec.RegisterHook(&loggerSQLHook{}) } type loggerSQLHook struct{} @@ -44,6 +44,7 @@ func (h *loggerSQLHook) Before(ctx context.Context, execCtx *types.ExecContext) } fields := []zap.Field{ zap.String("tx-id", txID), + zap.String("xid", execCtx.TxCtx.XaID), zap.String("sql", execCtx.Query), } @@ -55,7 +56,7 @@ func (h *loggerSQLHook) Before(ctx context.Context, execCtx *types.ExecContext) fields = append(fields, zap.Any("values", execCtx.Values)) } - log.Debug("sql exec log", fields) + log.Info("sql exec log", fields) return nil } diff --git a/pkg/datasource/sql/hook/undo_log_hook.go b/pkg/datasource/sql/hook/undo_log_hook.go new file mode 100644 index 000000000..327bf7b93 --- /dev/null +++ b/pkg/datasource/sql/hook/undo_log_hook.go @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package hook + +import ( + "context" + + "github.com/seata/seata-go/pkg/datasource/sql/exec" + "github.com/seata/seata-go/pkg/datasource/sql/parser" + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/datasource/sql/undo" + "github.com/seata/seata-go/pkg/tm" +) + +func init() { + exec.RegisterHook(&undoLogSQLHook{}) +} + +type undoLogSQLHook struct { +} + +func (h *undoLogSQLHook) Type() types.SQLType { + return types.SQLTypeUnknown +} + +// Before +func (h *undoLogSQLHook) Before(ctx context.Context, execCtx *types.ExecContext) error { + if !tm.IsTransactionOpened(ctx) { + return nil + } + + pc, err := parser.DoParser(execCtx.Query) + if err != nil { + return err + } + if !pc.HasValidStmt() { + return nil + } + + builder := undo.GetUndologBuilder(pc.ExecutorType) + if builder == nil { + return nil + } + recordImage, err := builder.BeforeImage(ctx, execCtx) + if err != nil { + return err + } + execCtx.TxCtx.RoundImages.AppendBeofreImages(recordImage) + return nil +} + +// After +func (h *undoLogSQLHook) After(ctx context.Context, execCtx *types.ExecContext) error { + if !tm.IsTransactionOpened(ctx) { + return nil + } + return nil +} diff --git a/pkg/datasource/sql/mock/README.md b/pkg/datasource/sql/mock/README.md new file mode 100644 index 000000000..554aa1cb8 --- /dev/null +++ b/pkg/datasource/sql/mock/README.md @@ -0,0 +1,4 @@ +```bash +mockgen -source=test_driver.go -destination=./mock_driver.go -package=mock +mockgen -source=../datasource/datasource_manager.go -destination=./mock_datasource_manager.go -package=mock +``` \ No newline at end of file diff --git a/pkg/datasource/sql/mock/mock_datasource_manager.go b/pkg/datasource/sql/mock/mock_datasource_manager.go index cf4c42cb4..16ea98b8a 100644 --- a/pkg/datasource/sql/mock/mock_datasource_manager.go +++ b/pkg/datasource/sql/mock/mock_datasource_manager.go @@ -1,22 +1,5 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - // Code generated by MockGen. DO NOT EDIT. -// Source: datasource_manager.go +// Source: ../datasource/datasource_manager.go // Package mock is a generated GoMock package. package mock @@ -34,72 +17,74 @@ import ( rm "github.com/seata/seata-go/pkg/rm" ) -// MockDataSourceManager is a mock of DataSourceManager interface +// MockDataSourceManager is a mock of DataSourceManager interface. type MockDataSourceManager struct { ctrl *gomock.Controller recorder *MockDataSourceManagerMockRecorder } -// MockDataSourceManagerMockRecorder is the mock recorder for MockDataSourceManager +// MockDataSourceManagerMockRecorder is the mock recorder for MockDataSourceManager. type MockDataSourceManagerMockRecorder struct { mock *MockDataSourceManager } -// NewMockDataSourceManager creates a new mock instance +// NewMockDataSourceManager creates a new mock instance. func NewMockDataSourceManager(ctrl *gomock.Controller) *MockDataSourceManager { mock := &MockDataSourceManager{ctrl: ctrl} mock.recorder = &MockDataSourceManagerMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockDataSourceManager) EXPECT() *MockDataSourceManagerMockRecorder { return m.recorder } -// RegisterResource mocks base method -func (m *MockDataSourceManager) RegisterResource(resource rm.Resource) error { +// BranchCommit mocks base method. +func (m *MockDataSourceManager) BranchCommit(ctx context.Context, req message.BranchCommitRequest) (branch.BranchStatus, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RegisterResource", resource) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "BranchCommit", ctx, req) + ret0, _ := ret[0].(branch.BranchStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// RegisterResource indicates an expected call of RegisterResource -func (mr *MockDataSourceManagerMockRecorder) RegisterResource(resource interface{}) *gomock.Call { +// BranchCommit indicates an expected call of BranchCommit. +func (mr *MockDataSourceManagerMockRecorder) BranchCommit(ctx, req interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterResource", reflect.TypeOf((*MockDataSourceManager)(nil).RegisterResource), resource) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchCommit", reflect.TypeOf((*MockDataSourceManager)(nil).BranchCommit), ctx, req) } -// UnregisterResource mocks base method -func (m *MockDataSourceManager) UnregisterResource(resource rm.Resource) error { +// BranchRegister mocks base method. +func (m *MockDataSourceManager) BranchRegister(ctx context.Context, clientId string, req message.BranchRegisterRequest) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UnregisterResource", resource) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "BranchRegister", ctx, clientId, req) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// UnregisterResource indicates an expected call of UnregisterResource -func (mr *MockDataSourceManagerMockRecorder) UnregisterResource(resource interface{}) *gomock.Call { +// BranchRegister indicates an expected call of BranchRegister. +func (mr *MockDataSourceManagerMockRecorder) BranchRegister(ctx, clientId, req interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterResource", reflect.TypeOf((*MockDataSourceManager)(nil).UnregisterResource), resource) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRegister", reflect.TypeOf((*MockDataSourceManager)(nil).BranchRegister), ctx, clientId, req) } -// GetManagedResources mocks base method -func (m *MockDataSourceManager) GetManagedResources() map[string]rm.Resource { +// BranchReport mocks base method. +func (m *MockDataSourceManager) BranchReport(ctx context.Context, req message.BranchReportRequest) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetManagedResources") - ret0, _ := ret[0].(map[string]rm.Resource) + ret := m.ctrl.Call(m, "BranchReport", ctx, req) + ret0, _ := ret[0].(error) return ret0 } -// GetManagedResources indicates an expected call of GetManagedResources -func (mr *MockDataSourceManagerMockRecorder) GetManagedResources() *gomock.Call { +// BranchReport indicates an expected call of BranchReport. +func (mr *MockDataSourceManagerMockRecorder) BranchReport(ctx, req interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedResources", reflect.TypeOf((*MockDataSourceManager)(nil).GetManagedResources)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchReport", reflect.TypeOf((*MockDataSourceManager)(nil).BranchReport), ctx, req) } -// BranchRollback mocks base method +// BranchRollback mocks base method. func (m *MockDataSourceManager) BranchRollback(ctx context.Context, req message.BranchRollbackRequest) (branch.BranchStatus, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "BranchRollback", ctx, req) @@ -108,124 +93,122 @@ func (m *MockDataSourceManager) BranchRollback(ctx context.Context, req message. return ret0, ret1 } -// BranchRollback indicates an expected call of BranchRollback +// BranchRollback indicates an expected call of BranchRollback. func (mr *MockDataSourceManagerMockRecorder) BranchRollback(ctx, req interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRollback", reflect.TypeOf((*MockDataSourceManager)(nil).BranchRollback), ctx, req) } -// BranchCommit mocks base method -func (m *MockDataSourceManager) BranchCommit(ctx context.Context, req message.BranchCommitRequest) (branch.BranchStatus, error) { +// CreateTableMetaCache mocks base method. +func (m *MockDataSourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (datasource.TableMetaCache, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BranchCommit", ctx, req) - ret0, _ := ret[0].(branch.BranchStatus) + ret := m.ctrl.Call(m, "CreateTableMetaCache", ctx, resID, dbType, db) + ret0, _ := ret[0].(datasource.TableMetaCache) ret1, _ := ret[1].(error) return ret0, ret1 } -// BranchCommit indicates an expected call of BranchCommit -func (mr *MockDataSourceManagerMockRecorder) BranchCommit(ctx, req interface{}) *gomock.Call { +// CreateTableMetaCache indicates an expected call of CreateTableMetaCache. +func (mr *MockDataSourceManagerMockRecorder) CreateTableMetaCache(ctx, resID, dbType, db interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchCommit", reflect.TypeOf((*MockDataSourceManager)(nil).BranchCommit), ctx, req) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTableMetaCache", reflect.TypeOf((*MockDataSourceManager)(nil).CreateTableMetaCache), ctx, resID, dbType, db) } -// LockQuery mocks base method -func (m *MockDataSourceManager) LockQuery(ctx context.Context, req message.GlobalLockQueryRequest) (bool, error) { +// GetManagedResources mocks base method. +func (m *MockDataSourceManager) GetManagedResources() map[string]rm.Resource { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LockQuery", ctx, req) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "GetManagedResources") + ret0, _ := ret[0].(map[string]rm.Resource) + return ret0 } -// LockQuery indicates an expected call of LockQuery -func (mr *MockDataSourceManagerMockRecorder) LockQuery(ctx, req interface{}) *gomock.Call { +// GetManagedResources indicates an expected call of GetManagedResources. +func (mr *MockDataSourceManagerMockRecorder) GetManagedResources() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LockQuery", reflect.TypeOf((*MockDataSourceManager)(nil).LockQuery), ctx, req) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedResources", reflect.TypeOf((*MockDataSourceManager)(nil).GetManagedResources)) } -// BranchRegister mocks base method -func (m *MockDataSourceManager) BranchRegister(ctx context.Context, clientId string, req message.BranchRegisterRequest) (int64, error) { +// LockQuery mocks base method. +func (m *MockDataSourceManager) LockQuery(ctx context.Context, req message.GlobalLockQueryRequest) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BranchRegister", ctx, clientId, req) - ret0, _ := ret[0].(int64) + ret := m.ctrl.Call(m, "LockQuery", ctx, req) + ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } -// BranchRegister indicates an expected call of BranchRegister -func (mr *MockDataSourceManagerMockRecorder) BranchRegister(ctx, clientId, req interface{}) *gomock.Call { +// LockQuery indicates an expected call of LockQuery. +func (mr *MockDataSourceManagerMockRecorder) LockQuery(ctx, req interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRegister", reflect.TypeOf((*MockDataSourceManager)(nil).BranchRegister), ctx, clientId, req) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LockQuery", reflect.TypeOf((*MockDataSourceManager)(nil).LockQuery), ctx, req) } -// BranchReport mocks base method -func (m *MockDataSourceManager) BranchReport(ctx context.Context, req message.BranchReportRequest) error { +// RegisterResource mocks base method. +func (m *MockDataSourceManager) RegisterResource(resource rm.Resource) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BranchReport", ctx, req) + ret := m.ctrl.Call(m, "RegisterResource", resource) ret0, _ := ret[0].(error) return ret0 } -// BranchReport indicates an expected call of BranchReport -func (mr *MockDataSourceManagerMockRecorder) BranchReport(ctx, req interface{}) *gomock.Call { +// RegisterResource indicates an expected call of RegisterResource. +func (mr *MockDataSourceManagerMockRecorder) RegisterResource(resource interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchReport", reflect.TypeOf((*MockDataSourceManager)(nil).BranchReport), ctx, req) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterResource", reflect.TypeOf((*MockDataSourceManager)(nil).RegisterResource), resource) } -// CreateTableMetaCache mocks base method -func (m *MockDataSourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (datasource.TableMetaCache, error) { +// UnregisterResource mocks base method. +func (m *MockDataSourceManager) UnregisterResource(resource rm.Resource) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateTableMetaCache", ctx, resID, dbType, db) - ret0, _ := ret[0].(datasource.TableMetaCache) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "UnregisterResource", resource) + ret0, _ := ret[0].(error) + return ret0 } -// CreateTableMetaCache indicates an expected call of CreateTableMetaCache -func (mr *MockDataSourceManagerMockRecorder) CreateTableMetaCache(ctx, resID, dbType, db interface{}) *gomock.Call { +// UnregisterResource indicates an expected call of UnregisterResource. +func (mr *MockDataSourceManagerMockRecorder) UnregisterResource(resource interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTableMetaCache", reflect.TypeOf((*MockDataSourceManager)(nil).CreateTableMetaCache), ctx, resID, dbType, db) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterResource", reflect.TypeOf((*MockDataSourceManager)(nil).UnregisterResource), resource) } -// MockTableMetaCache is a mock of TableMetaCache interface +// MockTableMetaCache is a mock of TableMetaCache interface. type MockTableMetaCache struct { ctrl *gomock.Controller recorder *MockTableMetaCacheMockRecorder } -// MockTableMetaCacheMockRecorder is the mock recorder for MockTableMetaCache +// MockTableMetaCacheMockRecorder is the mock recorder for MockTableMetaCache. type MockTableMetaCacheMockRecorder struct { mock *MockTableMetaCache } -// NewMockTableMetaCache creates a new mock instance +// NewMockTableMetaCache creates a new mock instance. func NewMockTableMetaCache(ctrl *gomock.Controller) *MockTableMetaCache { mock := &MockTableMetaCache{ctrl: ctrl} mock.recorder = &MockTableMetaCacheMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockTableMetaCache) EXPECT() *MockTableMetaCacheMockRecorder { return m.recorder } -// Init mocks base method -func (m *MockTableMetaCache) Init(ctx context.Context, conn *sql.DB) error { +// Destroy mocks base method. +func (m *MockTableMetaCache) Destroy() error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Init", ctx, conn) + ret := m.ctrl.Call(m, "Destroy") ret0, _ := ret[0].(error) return ret0 } -// Init indicates an expected call of Init -func (mr *MockTableMetaCacheMockRecorder) Init(ctx, conn interface{}) *gomock.Call { +// Destroy indicates an expected call of Destroy. +func (mr *MockTableMetaCacheMockRecorder) Destroy() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockTableMetaCache)(nil).Init), ctx, conn) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockTableMetaCache)(nil).Destroy)) } -// GetTableMeta mocks base method +// GetTableMeta mocks base method. func (m *MockTableMetaCache) GetTableMeta(table string) (types.TableMeta, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetTableMeta", table) @@ -234,22 +217,22 @@ func (m *MockTableMetaCache) GetTableMeta(table string) (types.TableMeta, error) return ret0, ret1 } -// GetTableMeta indicates an expected call of GetTableMeta +// GetTableMeta indicates an expected call of GetTableMeta. func (mr *MockTableMetaCacheMockRecorder) GetTableMeta(table interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMeta", reflect.TypeOf((*MockTableMetaCache)(nil).GetTableMeta), table) } -// Destroy mocks base method -func (m *MockTableMetaCache) Destroy() error { +// Init mocks base method. +func (m *MockTableMetaCache) Init(ctx context.Context, conn *sql.DB) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Destroy") + ret := m.ctrl.Call(m, "Init", ctx, conn) ret0, _ := ret[0].(error) return ret0 } -// Destroy indicates an expected call of Destroy -func (mr *MockTableMetaCacheMockRecorder) Destroy() *gomock.Call { +// Init indicates an expected call of Init. +func (mr *MockTableMetaCacheMockRecorder) Init(ctx, conn interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockTableMetaCache)(nil).Destroy)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockTableMetaCache)(nil).Init), ctx, conn) } diff --git a/pkg/datasource/sql/mock/mock_driver.go b/pkg/datasource/sql/mock/mock_driver.go index eedc9d59d..08dcdbaa4 100644 --- a/pkg/datasource/sql/mock/mock_driver.go +++ b/pkg/datasource/sql/mock/mock_driver.go @@ -1,20 +1,3 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - // Code generated by MockGen. DO NOT EDIT. // Source: test_driver.go @@ -29,30 +12,30 @@ import ( gomock "github.com/golang/mock/gomock" ) -// MockTestDriverConnector is a mock of TestDriverConnector interface +// MockTestDriverConnector is a mock of TestDriverConnector interface. type MockTestDriverConnector struct { ctrl *gomock.Controller recorder *MockTestDriverConnectorMockRecorder } -// MockTestDriverConnectorMockRecorder is the mock recorder for MockTestDriverConnector +// MockTestDriverConnectorMockRecorder is the mock recorder for MockTestDriverConnector. type MockTestDriverConnectorMockRecorder struct { mock *MockTestDriverConnector } -// NewMockTestDriverConnector creates a new mock instance +// NewMockTestDriverConnector creates a new mock instance. func NewMockTestDriverConnector(ctrl *gomock.Controller) *MockTestDriverConnector { mock := &MockTestDriverConnector{ctrl: ctrl} mock.recorder = &MockTestDriverConnectorMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockTestDriverConnector) EXPECT() *MockTestDriverConnectorMockRecorder { return m.recorder } -// Connect mocks base method +// Connect mocks base method. func (m *MockTestDriverConnector) Connect(arg0 context.Context) (driver.Conn, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Connect", arg0) @@ -61,13 +44,13 @@ func (m *MockTestDriverConnector) Connect(arg0 context.Context) (driver.Conn, er return ret0, ret1 } -// Connect indicates an expected call of Connect +// Connect indicates an expected call of Connect. func (mr *MockTestDriverConnectorMockRecorder) Connect(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockTestDriverConnector)(nil).Connect), arg0) } -// Driver mocks base method +// Driver mocks base method. func (m *MockTestDriverConnector) Driver() driver.Driver { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Driver") @@ -75,51 +58,66 @@ func (m *MockTestDriverConnector) Driver() driver.Driver { return ret0 } -// Driver indicates an expected call of Driver +// Driver indicates an expected call of Driver. func (mr *MockTestDriverConnectorMockRecorder) Driver() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Driver", reflect.TypeOf((*MockTestDriverConnector)(nil).Driver)) } -// MockTestDriverConn is a mock of TestDriverConn interface +// MockTestDriverConn is a mock of TestDriverConn interface. type MockTestDriverConn struct { ctrl *gomock.Controller recorder *MockTestDriverConnMockRecorder } -// MockTestDriverConnMockRecorder is the mock recorder for MockTestDriverConn +// MockTestDriverConnMockRecorder is the mock recorder for MockTestDriverConn. type MockTestDriverConnMockRecorder struct { mock *MockTestDriverConn } -// NewMockTestDriverConn creates a new mock instance +// NewMockTestDriverConn creates a new mock instance. func NewMockTestDriverConn(ctrl *gomock.Controller) *MockTestDriverConn { mock := &MockTestDriverConn{ctrl: ctrl} mock.recorder = &MockTestDriverConnMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockTestDriverConn) EXPECT() *MockTestDriverConnMockRecorder { return m.recorder } -// Prepare mocks base method -func (m *MockTestDriverConn) Prepare(query string) (driver.Stmt, error) { +// Begin mocks base method. +func (m *MockTestDriverConn) Begin() (driver.Tx, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Prepare", query) - ret0, _ := ret[0].(driver.Stmt) + ret := m.ctrl.Call(m, "Begin") + ret0, _ := ret[0].(driver.Tx) ret1, _ := ret[1].(error) return ret0, ret1 } -// Prepare indicates an expected call of Prepare -func (mr *MockTestDriverConnMockRecorder) Prepare(query interface{}) *gomock.Call { +// Begin indicates an expected call of Begin. +func (mr *MockTestDriverConnMockRecorder) Begin() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockTestDriverConn)(nil).Prepare), query) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTestDriverConn)(nil).Begin)) +} + +// BeginTx mocks base method. +func (m *MockTestDriverConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", ctx, opts) + ret0, _ := ret[0].(driver.Tx) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. +func (mr *MockTestDriverConnMockRecorder) BeginTx(ctx, opts interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockTestDriverConn)(nil).BeginTx), ctx, opts) } -// Close mocks base method +// Close mocks base method. func (m *MockTestDriverConn) Close() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Close") @@ -127,43 +125,43 @@ func (m *MockTestDriverConn) Close() error { return ret0 } -// Close indicates an expected call of Close +// Close indicates an expected call of Close. func (mr *MockTestDriverConnMockRecorder) Close() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockTestDriverConn)(nil).Close)) } -// Begin mocks base method -func (m *MockTestDriverConn) Begin() (driver.Tx, error) { +// Exec mocks base method. +func (m *MockTestDriverConn) Exec(query string, args []driver.Value) (driver.Result, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Begin") - ret0, _ := ret[0].(driver.Tx) + ret := m.ctrl.Call(m, "Exec", query, args) + ret0, _ := ret[0].(driver.Result) ret1, _ := ret[1].(error) return ret0, ret1 } -// Begin indicates an expected call of Begin -func (mr *MockTestDriverConnMockRecorder) Begin() *gomock.Call { +// Exec indicates an expected call of Exec. +func (mr *MockTestDriverConnMockRecorder) Exec(query, args interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTestDriverConn)(nil).Begin)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTestDriverConn)(nil).Exec), query, args) } -// BeginTx mocks base method -func (m *MockTestDriverConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { +// ExecContext mocks base method. +func (m *MockTestDriverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BeginTx", ctx, opts) - ret0, _ := ret[0].(driver.Tx) + ret := m.ctrl.Call(m, "ExecContext", ctx, query, args) + ret0, _ := ret[0].(driver.Result) ret1, _ := ret[1].(error) return ret0, ret1 } -// BeginTx indicates an expected call of BeginTx -func (mr *MockTestDriverConnMockRecorder) BeginTx(ctx, opts interface{}) *gomock.Call { +// ExecContext indicates an expected call of ExecContext. +func (mr *MockTestDriverConnMockRecorder) ExecContext(ctx, query, args interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockTestDriverConn)(nil).BeginTx), ctx, opts) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecContext", reflect.TypeOf((*MockTestDriverConn)(nil).ExecContext), ctx, query, args) } -// Ping mocks base method +// Ping mocks base method. func (m *MockTestDriverConn) Ping(ctx context.Context) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Ping", ctx) @@ -171,13 +169,28 @@ func (m *MockTestDriverConn) Ping(ctx context.Context) error { return ret0 } -// Ping indicates an expected call of Ping +// Ping indicates an expected call of Ping. func (mr *MockTestDriverConnMockRecorder) Ping(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockTestDriverConn)(nil).Ping), ctx) } -// PrepareContext mocks base method +// Prepare mocks base method. +func (m *MockTestDriverConn) Prepare(query string) (driver.Stmt, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", query) + ret0, _ := ret[0].(driver.Stmt) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockTestDriverConnMockRecorder) Prepare(query interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockTestDriverConn)(nil).Prepare), query) +} + +// PrepareContext mocks base method. func (m *MockTestDriverConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PrepareContext", ctx, query) @@ -186,13 +199,13 @@ func (m *MockTestDriverConn) PrepareContext(ctx context.Context, query string) ( return ret0, ret1 } -// PrepareContext indicates an expected call of PrepareContext +// PrepareContext indicates an expected call of PrepareContext. func (mr *MockTestDriverConnMockRecorder) PrepareContext(ctx, query interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrepareContext", reflect.TypeOf((*MockTestDriverConn)(nil).PrepareContext), ctx, query) } -// Query mocks base method +// Query mocks base method. func (m *MockTestDriverConn) Query(query string, args []driver.Value) (driver.Rows, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Query", query, args) @@ -201,13 +214,13 @@ func (m *MockTestDriverConn) Query(query string, args []driver.Value) (driver.Ro return ret0, ret1 } -// Query indicates an expected call of Query +// Query indicates an expected call of Query. func (mr *MockTestDriverConnMockRecorder) Query(query, args interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTestDriverConn)(nil).Query), query, args) } -// QueryContext mocks base method +// QueryContext mocks base method. func (m *MockTestDriverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "QueryContext", ctx, query, args) @@ -216,43 +229,13 @@ func (m *MockTestDriverConn) QueryContext(ctx context.Context, query string, arg return ret0, ret1 } -// QueryContext indicates an expected call of QueryContext +// QueryContext indicates an expected call of QueryContext. func (mr *MockTestDriverConnMockRecorder) QueryContext(ctx, query, args interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockTestDriverConn)(nil).QueryContext), ctx, query, args) } -// Exec mocks base method -func (m *MockTestDriverConn) Exec(query string, args []driver.Value) (driver.Result, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Exec", query, args) - ret0, _ := ret[0].(driver.Result) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Exec indicates an expected call of Exec -func (mr *MockTestDriverConnMockRecorder) Exec(query, args interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTestDriverConn)(nil).Exec), query, args) -} - -// ExecContext mocks base method -func (m *MockTestDriverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ExecContext", ctx, query, args) - ret0, _ := ret[0].(driver.Result) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ExecContext indicates an expected call of ExecContext -func (mr *MockTestDriverConnMockRecorder) ExecContext(ctx, query, args interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecContext", reflect.TypeOf((*MockTestDriverConn)(nil).ExecContext), ctx, query, args) -} - -// ResetSession mocks base method +// ResetSession mocks base method. func (m *MockTestDriverConn) ResetSession(ctx context.Context) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ResetSession", ctx) @@ -260,36 +243,36 @@ func (m *MockTestDriverConn) ResetSession(ctx context.Context) error { return ret0 } -// ResetSession indicates an expected call of ResetSession +// ResetSession indicates an expected call of ResetSession. func (mr *MockTestDriverConnMockRecorder) ResetSession(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetSession", reflect.TypeOf((*MockTestDriverConn)(nil).ResetSession), ctx) } -// MockTestDriverStmt is a mock of TestDriverStmt interface +// MockTestDriverStmt is a mock of TestDriverStmt interface. type MockTestDriverStmt struct { ctrl *gomock.Controller recorder *MockTestDriverStmtMockRecorder } -// MockTestDriverStmtMockRecorder is the mock recorder for MockTestDriverStmt +// MockTestDriverStmtMockRecorder is the mock recorder for MockTestDriverStmt. type MockTestDriverStmtMockRecorder struct { mock *MockTestDriverStmt } -// NewMockTestDriverStmt creates a new mock instance +// NewMockTestDriverStmt creates a new mock instance. func NewMockTestDriverStmt(ctrl *gomock.Controller) *MockTestDriverStmt { mock := &MockTestDriverStmt{ctrl: ctrl} mock.recorder = &MockTestDriverStmtMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockTestDriverStmt) EXPECT() *MockTestDriverStmtMockRecorder { return m.recorder } -// Close mocks base method +// Close mocks base method. func (m *MockTestDriverStmt) Close() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Close") @@ -297,42 +280,57 @@ func (m *MockTestDriverStmt) Close() error { return ret0 } -// Close indicates an expected call of Close +// Close indicates an expected call of Close. func (mr *MockTestDriverStmtMockRecorder) Close() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockTestDriverStmt)(nil).Close)) } -// NumInput mocks base method -func (m *MockTestDriverStmt) NumInput() int { +// Exec mocks base method. +func (m *MockTestDriverStmt) Exec(args []driver.Value) (driver.Result, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NumInput") - ret0, _ := ret[0].(int) - return ret0 + ret := m.ctrl.Call(m, "Exec", args) + ret0, _ := ret[0].(driver.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// NumInput indicates an expected call of NumInput -func (mr *MockTestDriverStmtMockRecorder) NumInput() *gomock.Call { +// Exec indicates an expected call of Exec. +func (mr *MockTestDriverStmtMockRecorder) Exec(args interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NumInput", reflect.TypeOf((*MockTestDriverStmt)(nil).NumInput)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTestDriverStmt)(nil).Exec), args) } -// Exec mocks base method -func (m *MockTestDriverStmt) Exec(args []driver.Value) (driver.Result, error) { +// ExecContext mocks base method. +func (m *MockTestDriverStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Exec", args) + ret := m.ctrl.Call(m, "ExecContext", ctx, args) ret0, _ := ret[0].(driver.Result) ret1, _ := ret[1].(error) return ret0, ret1 } -// Exec indicates an expected call of Exec -func (mr *MockTestDriverStmtMockRecorder) Exec(args interface{}) *gomock.Call { +// ExecContext indicates an expected call of ExecContext. +func (mr *MockTestDriverStmtMockRecorder) ExecContext(ctx, args interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTestDriverStmt)(nil).Exec), args) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecContext", reflect.TypeOf((*MockTestDriverStmt)(nil).ExecContext), ctx, args) +} + +// NumInput mocks base method. +func (m *MockTestDriverStmt) NumInput() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NumInput") + ret0, _ := ret[0].(int) + return ret0 +} + +// NumInput indicates an expected call of NumInput. +func (mr *MockTestDriverStmtMockRecorder) NumInput() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NumInput", reflect.TypeOf((*MockTestDriverStmt)(nil).NumInput)) } -// Query mocks base method +// Query mocks base method. func (m *MockTestDriverStmt) Query(args []driver.Value) (driver.Rows, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Query", args) @@ -341,13 +339,13 @@ func (m *MockTestDriverStmt) Query(args []driver.Value) (driver.Rows, error) { return ret0, ret1 } -// Query indicates an expected call of Query +// Query indicates an expected call of Query. func (mr *MockTestDriverStmtMockRecorder) Query(args interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTestDriverStmt)(nil).Query), args) } -// QueryContext mocks base method +// QueryContext mocks base method. func (m *MockTestDriverStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "QueryContext", ctx, args) @@ -356,51 +354,36 @@ func (m *MockTestDriverStmt) QueryContext(ctx context.Context, args []driver.Nam return ret0, ret1 } -// QueryContext indicates an expected call of QueryContext +// QueryContext indicates an expected call of QueryContext. func (mr *MockTestDriverStmtMockRecorder) QueryContext(ctx, args interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockTestDriverStmt)(nil).QueryContext), ctx, args) } -// ExecContext mocks base method -func (m *MockTestDriverStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ExecContext", ctx, args) - ret0, _ := ret[0].(driver.Result) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ExecContext indicates an expected call of ExecContext -func (mr *MockTestDriverStmtMockRecorder) ExecContext(ctx, args interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecContext", reflect.TypeOf((*MockTestDriverStmt)(nil).ExecContext), ctx, args) -} - -// MockTestDriverTx is a mock of TestDriverTx interface +// MockTestDriverTx is a mock of TestDriverTx interface. type MockTestDriverTx struct { ctrl *gomock.Controller recorder *MockTestDriverTxMockRecorder } -// MockTestDriverTxMockRecorder is the mock recorder for MockTestDriverTx +// MockTestDriverTxMockRecorder is the mock recorder for MockTestDriverTx. type MockTestDriverTxMockRecorder struct { mock *MockTestDriverTx } -// NewMockTestDriverTx creates a new mock instance +// NewMockTestDriverTx creates a new mock instance. func NewMockTestDriverTx(ctrl *gomock.Controller) *MockTestDriverTx { mock := &MockTestDriverTx{ctrl: ctrl} mock.recorder = &MockTestDriverTxMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockTestDriverTx) EXPECT() *MockTestDriverTxMockRecorder { return m.recorder } -// Commit mocks base method +// Commit mocks base method. func (m *MockTestDriverTx) Commit() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Commit") @@ -408,13 +391,13 @@ func (m *MockTestDriverTx) Commit() error { return ret0 } -// Commit indicates an expected call of Commit +// Commit indicates an expected call of Commit. func (mr *MockTestDriverTxMockRecorder) Commit() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTestDriverTx)(nil).Commit)) } -// Rollback mocks base method +// Rollback mocks base method. func (m *MockTestDriverTx) Rollback() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Rollback") @@ -422,8 +405,111 @@ func (m *MockTestDriverTx) Rollback() error { return ret0 } -// Rollback indicates an expected call of Rollback +// Rollback indicates an expected call of Rollback. func (mr *MockTestDriverTxMockRecorder) Rollback() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTestDriverTx)(nil).Rollback)) } + +// MockTestDriverRows is a mock of TestDriverRows interface. +type MockTestDriverRows struct { + ctrl *gomock.Controller + recorder *MockTestDriverRowsMockRecorder +} + +// MockTestDriverRowsMockRecorder is the mock recorder for MockTestDriverRows. +type MockTestDriverRowsMockRecorder struct { + mock *MockTestDriverRows +} + +// NewMockTestDriverRows creates a new mock instance. +func NewMockTestDriverRows(ctrl *gomock.Controller) *MockTestDriverRows { + mock := &MockTestDriverRows{ctrl: ctrl} + mock.recorder = &MockTestDriverRowsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTestDriverRows) EXPECT() *MockTestDriverRowsMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockTestDriverRows) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockTestDriverRowsMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockTestDriverRows)(nil).Close)) +} + +// Columns mocks base method. +func (m *MockTestDriverRows) Columns() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Columns") + ret0, _ := ret[0].([]string) + return ret0 +} + +// Columns indicates an expected call of Columns. +func (mr *MockTestDriverRowsMockRecorder) Columns() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Columns", reflect.TypeOf((*MockTestDriverRows)(nil).Columns)) +} + +// Next mocks base method. +func (m *MockTestDriverRows) Next(dest []driver.Value) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Next", dest) + ret0, _ := ret[0].(error) + return ret0 +} + +// Next indicates an expected call of Next. +func (mr *MockTestDriverRowsMockRecorder) Next(dest interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockTestDriverRows)(nil).Next), dest) +} + +// MockTestDriver is a mock of TestDriver interface. +type MockTestDriver struct { + ctrl *gomock.Controller + recorder *MockTestDriverMockRecorder +} + +// MockTestDriverMockRecorder is the mock recorder for MockTestDriver. +type MockTestDriverMockRecorder struct { + mock *MockTestDriver +} + +// NewMockTestDriver creates a new mock instance. +func NewMockTestDriver(ctrl *gomock.Controller) *MockTestDriver { + mock := &MockTestDriver{ctrl: ctrl} + mock.recorder = &MockTestDriverMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTestDriver) EXPECT() *MockTestDriverMockRecorder { + return m.recorder +} + +// Open mocks base method. +func (m *MockTestDriver) Open(name string) (driver.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Open", name) + ret0, _ := ret[0].(driver.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Open indicates an expected call of Open. +func (mr *MockTestDriverMockRecorder) Open(name interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockTestDriver)(nil).Open), name) +} diff --git a/pkg/datasource/sql/mock/test_driver.go b/pkg/datasource/sql/mock/test_driver.go index 0d193f62d..855e2424b 100644 --- a/pkg/datasource/sql/mock/test_driver.go +++ b/pkg/datasource/sql/mock/test_driver.go @@ -44,3 +44,11 @@ type TestDriverStmt interface { type TestDriverTx interface { driver.Tx } + +type TestDriverRows interface { + driver.Rows +} + +type TestDriver interface { + driver.Driver +} diff --git a/pkg/datasource/sql/plugin.go b/pkg/datasource/sql/plugin.go index eab60f44f..4e8afa89c 100644 --- a/pkg/datasource/sql/plugin.go +++ b/pkg/datasource/sql/plugin.go @@ -23,4 +23,7 @@ import ( // mysql 相关插件 _ "github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql" _ "github.com/seata/seata-go/pkg/datasource/sql/undo/mysql" + + _ "github.com/seata/seata-go/pkg/datasource/sql/exec/at" + _ "github.com/seata/seata-go/pkg/datasource/sql/exec/xa" ) diff --git a/pkg/datasource/sql/sql_test.go b/pkg/datasource/sql/sql_test.go deleted file mode 100644 index db851ebb3..000000000 --- a/pkg/datasource/sql/sql_test.go +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package sql - -import ( - "context" - "database/sql" - "fmt" - "sync" - "testing" - - _ "github.com/go-sql-driver/mysql" - "github.com/seata/seata-go/pkg/client" - "github.com/seata/seata-go/pkg/util/log" -) - -var db *sql.DB - -func Test_SQLOpen(t *testing.T) { - client.Init() - t.SkipNow() - log.Info("begin test") - var err error - db, err = sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") - if err != nil { - t.Fatal(err) - } - - defer db.Close() - - sqlStmt := ` - create table if not exists foo (id integer not null primary key, name text); - delete from foo; - ` - _, err = db.Exec(sqlStmt) - if err != nil { - t.Fatal(err) - } - - wait := sync.WaitGroup{} - - txInvoker := func(prefix string, offset, total int) { - defer wait.Done() - - tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) - if err != nil { - t.Fatal(err) - } - - stmt, err := tx.Prepare("insert into foo(id, name) values(?, ?)") - if err != nil { - t.Fatal(err) - } - defer stmt.Close() - for i := 0; i < total; i++ { - _, err = stmt.Exec(i+offset, fmt.Sprintf("%s-%03d", prefix, i)) - if err != nil { - t.Fatal(err) - } - } - err = tx.Commit() - if err != nil { - t.Fatal(err) - } - } - - wait.Add(2) - - t.Parallel() - t.Run("", func(t *testing.T) { - txInvoker("seata-go-at-1", 0, 10) - }) - t.Run("", func(t *testing.T) { - txInvoker("seata-go-at-2", 20, 10) - }) - - wait.Wait() - queryMultiRow() -} - -func queryMultiRow() { - sqlStr := "select id, name from foo where id > ?" - rows, err := db.Query(sqlStr, 0) - if err != nil { - fmt.Printf("query failed, err:%v\n", err) - return - } - defer rows.Close() - - for rows.Next() { - var u user - err := rows.Scan(&u.id, &u.name) - if err != nil { - fmt.Printf("scan failed, err:%v\n", err) - return - } - fmt.Printf("id:%d username:%s password:%s\n", u.id, u.name, u.name) - } -} - -type user struct { - id int - name string -} diff --git a/pkg/datasource/sql/tx.go b/pkg/datasource/sql/tx.go index 8748f0cc1..d49742fad 100644 --- a/pkg/datasource/sql/tx.go +++ b/pkg/datasource/sql/tx.go @@ -22,14 +22,11 @@ import ( "database/sql/driver" "sync" - "github.com/seata/seata-go/pkg/datasource/sql/undo" - "github.com/seata/seata-go/pkg/datasource/sql/datasource" "github.com/seata/seata-go/pkg/protocol/branch" "github.com/seata/seata-go/pkg/protocol/message" "github.com/seata/seata-go/pkg/util/log" - "github.com/pkg/errors" "github.com/seata/seata-go/pkg/datasource/sql/types" ) @@ -107,10 +104,12 @@ type Tx struct { } // Commit do commit action -// case 1. no open global-transaction, just do local transaction commit -// case 2. not need flush undolog, is XA mode, do local transaction commit -// case 3. need run AT transaction func (tx *Tx) Commit() error { + tx.beforeCommit() + return tx.commitOnLocal() +} + +func (tx *Tx) beforeCommit() { if len(txHooks) != 0 { hl.RLock() defer hl.RUnlock() @@ -119,17 +118,6 @@ func (tx *Tx) Commit() error { txHooks[i].BeforeCommit(tx) } } - - if tx.ctx.TransType == types.Local { - return tx.commitOnLocal() - } - - // flush undo log if need, is XA mode - if tx.ctx.TransType == types.XAMode { - return tx.commitOnXA() - } - - return tx.commitOnAT() } func (tx *Tx) Rollback() error { @@ -142,14 +130,7 @@ func (tx *Tx) Rollback() error { } } - err := tx.target.Rollback() - if err != nil { - if tx.ctx.OpenGlobalTrsnaction() && tx.ctx.IsBranchRegistered() { - tx.report(false) - } - } - - return err + return tx.target.Rollback() } // init @@ -162,41 +143,6 @@ func (tx *Tx) commitOnLocal() error { return tx.target.Commit() } -// commitOnXA -func (tx *Tx) commitOnXA() error { - return nil -} - -// commitOnAT -func (tx *Tx) commitOnAT() error { - // if TX-Mode is AT, run regis this transaction branch - if err := tx.register(tx.ctx); err != nil { - return err - } - - undoLogMgr, err := undo.GetUndoLogManager(tx.ctx.DBType) - if err != nil { - return err - } - - if err := undoLogMgr.FlushUndoLog(tx.ctx, tx.conn.targetConn); err != nil { - if rerr := tx.report(false); rerr != nil { - return errors.WithStack(rerr) - } - return errors.WithStack(err) - } - - if err := tx.commitOnLocal(); err != nil { - if rerr := tx.report(false); rerr != nil { - return errors.WithStack(rerr) - } - return errors.WithStack(err) - } - - tx.report(true) - return nil -} - // register func (tx *Tx) register(ctx *types.TransactionContext) error { if !ctx.HasUndoLog() || !ctx.HasLockKey() { diff --git a/pkg/datasource/sql/tx_at.go b/pkg/datasource/sql/tx_at.go new file mode 100644 index 000000000..b982a599b --- /dev/null +++ b/pkg/datasource/sql/tx_at.go @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sql + +import ( + "github.com/seata/seata-go/pkg/datasource/sql/undo" + + "github.com/pkg/errors" +) + +// ATTx +type ATTx struct { + tx *Tx +} + +// Commit do commit action +// case 1. no open global-transaction, just do local transaction commit +// case 2. not need flush undolog, is XA mode, do local transaction commit +// case 3. need run AT transaction +func (tx *ATTx) Commit() error { + tx.tx.beforeCommit() + return tx.commitOnAT() +} + +func (tx *ATTx) Rollback() error { + err := tx.tx.Rollback() + if err != nil { + + originTx := tx.tx + + if originTx.ctx.OpenGlobalTrsnaction() && originTx.ctx.IsBranchRegistered() { + originTx.report(false) + } + } + + return err +} + +// commitOnAT +func (tx *ATTx) commitOnAT() error { + originTx := tx.tx + + if err := originTx.register(originTx.ctx); err != nil { + return err + } + + undoLogMgr, err := undo.GetUndoLogManager(originTx.ctx.DBType) + if err != nil { + return err + } + + if err := undoLogMgr.FlushUndoLog(originTx.ctx, originTx.conn.targetConn); err != nil { + if rerr := originTx.report(false); rerr != nil { + return errors.WithStack(rerr) + } + return errors.WithStack(err) + } + + if err := originTx.commitOnLocal(); err != nil { + if rerr := originTx.report(false); rerr != nil { + return errors.WithStack(rerr) + } + return errors.WithStack(err) + } + + originTx.report(true) + return nil +} diff --git a/pkg/datasource/sql/tx_xa.go b/pkg/datasource/sql/tx_xa.go new file mode 100644 index 000000000..85aa73c54 --- /dev/null +++ b/pkg/datasource/sql/tx_xa.go @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sql + +// XATx +type XATx struct { + tx *Tx +} + +// Commit do commit action +// case 1. no open global-transaction, just do local transaction commit +// case 2. not need flush undolog, is XA mode, do local transaction commit +// case 3. need run AT transaction +func (tx *XATx) Commit() error { + tx.tx.beforeCommit() + return tx.commitOnXA() +} + +func (tx *XATx) Rollback() error { + err := tx.tx.Rollback() + if err != nil { + + originTx := tx.tx + + if originTx.ctx.OpenGlobalTrsnaction() && originTx.ctx.IsBranchRegistered() { + originTx.report(false) + } + } + + return err +} + +// commitOnXA +func (tx *XATx) commitOnXA() error { + return nil +}