Skip to content

Commit

Permalink
refactor:split xa and at logic
Browse files Browse the repository at this point in the history
  • Loading branch information
chuntaojun committed Oct 3, 2022
1 parent 0ce5bef commit e0e1113
Show file tree
Hide file tree
Showing 10 changed files with 291 additions and 174 deletions.
10 changes: 6 additions & 4 deletions pkg/datasource/sql/conn_at.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ import (
"github.com/seata/seata-go/pkg/tm"
)

// ATConn Database connection proxy object under XA transaction model
// Conn is assumed to be stateful.
type ATConn struct {
*Conn
}

func (c *ATConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if c.createTxCtxIfAbsent(ctx) {
if c.createOnceTxContext(ctx) {
defer func() {
c.txCtx = types.NewTxCtx()
}()
Expand All @@ -41,7 +43,7 @@ func (c *ATConn) PrepareContext(ctx context.Context, query string) (driver.Stmt,

// QueryContext
func (c *ATConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
if c.createTxCtxIfAbsent(ctx) {
if c.createOnceTxContext(ctx) {
defer func() {
c.txCtx = types.NewTxCtx()
}()
Expand All @@ -52,7 +54,7 @@ func (c *ATConn) QueryContext(ctx context.Context, query string, args []driver.N

// ExecContext
func (c *ATConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if c.createTxCtxIfAbsent(ctx) {
if c.createOnceTxContext(ctx) {
defer func() {
c.txCtx = types.NewTxCtx()
}()
Expand Down Expand Up @@ -82,7 +84,7 @@ func (c *ATConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx,
return &ATTx{tx: tx.(*Tx)}, nil
}

func (c *ATConn) createTxCtxIfAbsent(ctx context.Context) bool {
func (c *ATConn) createOnceTxContext(ctx context.Context) bool {
onceTx := IsGlobalTx(ctx) && c.autoCommit

if onceTx {
Expand Down
107 changes: 102 additions & 5 deletions pkg/datasource/sql/conn_at_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@ import (
"github.com/stretchr/testify/assert"
)

func TestATConn_ExecContext(t *testing.T) {
func initAtConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQLInterceptor, *mockTxHook) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockMgr := initMockResourceManager(t, ctrl)
_ = mockMgr
Expand All @@ -45,8 +44,6 @@ func TestATConn_ExecContext(t *testing.T) {
t.Fatal(err)
}

defer db.Close()

_ = initMockAtConnector(t, ctrl, db, func(t *testing.T, ctrl *gomock.Controller) driver.Connector {
mockTx := mock.NewMockTestDriverTx(ctrl)
mockTx.EXPECT().Commit().AnyTimes().Return(nil)
Expand All @@ -68,9 +65,20 @@ func TestATConn_ExecContext(t *testing.T) {

exec.CleanCommonHook()
CleanTxHooks()
exec.RegisCommonHook(mi)
exec.RegisterCommonHook(mi)
RegisterTxHook(ti)

return ctrl, db, mi, ti
}

func TestATConn_ExecContext(t *testing.T) {
ctrl, db, mi, ti := initAtConnTestResource(t)
defer func() {
ctrl.Finish()
db.Close()
CleanTxHooks()
}()

t.Run("have xid", func(t *testing.T) {
ctx := tm.InitSeataContext(context.Background())
tm.SetXID(ctx, uuid.New().String())
Expand Down Expand Up @@ -126,3 +134,92 @@ func TestATConn_ExecContext(t *testing.T) {
assert.Equal(t, int32(0), atomic.LoadInt32(&comitCnt))
})
}

func TestATConn_BeginTx(t *testing.T) {
ctrl, db, mi, ti := initAtConnTestResource(t)
defer func() {
ctrl.Finish()
db.Close()
CleanTxHooks()
}()

t.Run("tx-local", func(t *testing.T) {
tx, err := db.Begin()
assert.NoError(t, err)

mi.before = func(_ context.Context, execCtx *exec.ExecContext) {
assert.Equal(t, "", execCtx.TxCtx.XaID)
assert.Equal(t, types.Local, execCtx.TxCtx.TransType)
}

var comitCnt int32
ti.beforeCommit = func(tx *Tx) {
atomic.AddInt32(&comitCnt, 1)
}

_, err = tx.ExecContext(context.Background(), "SELECT * FROM user")
assert.NoError(t, err)

_, err = tx.ExecContext(tm.InitSeataContext(context.Background()), "SELECT * FROM user")
assert.NoError(t, err)

err = tx.Commit()
assert.NoError(t, err)

assert.Equal(t, int32(1), atomic.LoadInt32(&comitCnt))
})

t.Run("tx-local-context", func(t *testing.T) {
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
assert.NoError(t, err)

mi.before = func(_ context.Context, execCtx *exec.ExecContext) {
assert.Equal(t, "", execCtx.TxCtx.XaID)
assert.Equal(t, types.Local, execCtx.TxCtx.TransType)
}

var comitCnt int32
ti.beforeCommit = func(tx *Tx) {
atomic.AddInt32(&comitCnt, 1)
}

_, err = tx.ExecContext(context.Background(), "SELECT * FROM user")
assert.NoError(t, err)

_, err = tx.ExecContext(tm.InitSeataContext(context.Background()), "SELECT * FROM user")
assert.NoError(t, err)

err = tx.Commit()
assert.NoError(t, err)

assert.Equal(t, int32(1), atomic.LoadInt32(&comitCnt))
})

t.Run("tx-at-context", func(t *testing.T) {
ctx := tm.InitSeataContext(context.Background())
tm.SetXID(ctx, uuid.NewString())
tx, err := db.BeginTx(ctx, &sql.TxOptions{})
assert.NoError(t, err)

mi.before = func(_ context.Context, execCtx *exec.ExecContext) {
assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID)
assert.Equal(t, types.ATMode, execCtx.TxCtx.TransType)
}

var comitCnt int32
ti.beforeCommit = func(tx *Tx) {
atomic.AddInt32(&comitCnt, 1)
}

_, err = tx.ExecContext(context.Background(), "SELECT * FROM user")
assert.NoError(t, err)

_, err = tx.ExecContext(context.Background(), "SELECT * FROM user")
assert.NoError(t, err)

err = tx.Commit()
assert.NoError(t, err)

assert.Equal(t, int32(1), atomic.LoadInt32(&comitCnt))
})
}
56 changes: 56 additions & 0 deletions pkg/datasource/sql/conn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sql

import (
"reflect"
"testing"

"github.com/seata/seata-go/pkg/common/reflectx"
"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/exec/xa"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/stretchr/testify/assert"
)

func TestConn_BuildATExecutor(t *testing.T) {
executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.ATMode, "SELECT * FROM user")

assert.NoError(t, err)
_, ok := executor.(*exec.BaseExecutor)
assert.True(t, ok, "need base executor")
}

func TestConn_BuildXAExecutor(t *testing.T) {
executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.XAMode, "SELECT * FROM user")

assert.NoError(t, err)
val, ok := executor.(*exec.BaseExecutor)
assert.True(t, ok, "need base executor")

v := reflect.ValueOf(val)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
field := v.FieldByName("ex")

fieldVal := reflectx.GetUnexportedField(field)

_, ok = fieldVal.(*xa.XAExecutor)
assert.True(t, ok, "need xa executor")
}
19 changes: 16 additions & 3 deletions pkg/datasource/sql/conn_xa.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,26 @@ import (
"github.com/seata/seata-go/pkg/tm"
)

// XAConn Database connection proxy object under XA transaction model
// Conn is assumed to be stateful.
type XAConn struct {
*Conn
}

// QueryContext
func (c *XAConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
if c.createOnceTxContext(ctx) {
defer func() {
c.txCtx = types.NewTxCtx()
}()
}

return c.Conn.QueryContext(ctx, query, args)
}

// PrepareContext
func (c *XAConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if c.createTxCtxIfAbsent(ctx) {
if c.createOnceTxContext(ctx) {
defer func() {
c.txCtx = types.NewTxCtx()
}()
Expand All @@ -42,7 +55,7 @@ func (c *XAConn) PrepareContext(ctx context.Context, query string) (driver.Stmt,

// ExecContext
func (c *XAConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if c.createTxCtxIfAbsent(ctx) {
if c.createOnceTxContext(ctx) {
defer func() {
c.txCtx = types.NewTxCtx()
}()
Expand Down Expand Up @@ -72,7 +85,7 @@ func (c *XAConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx,
return &XATx{tx: tx.(*Tx)}, nil
}

func (c *XAConn) createTxCtxIfAbsent(ctx context.Context) bool {
func (c *XAConn) createOnceTxContext(ctx context.Context) bool {
onceTx := IsGlobalTx(ctx) && c.autoCommit

if onceTx {
Expand Down
Loading

0 comments on commit e0e1113

Please sign in to comment.