Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix driver transaction behavior #84

Merged
merged 2 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 77 additions & 16 deletions driver/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"database/sql/driver"

"github.com/proullon/ramsql/engine/executor"
"github.com/proullon/ramsql/engine/log"
)

// Conn implements sql/driver Conn interface
Expand All @@ -29,7 +30,8 @@ import (
// https://pkg.go.dev/database/sql/driver#ConnPrepareContext
// https://pkg.go.dev/database/sql/driver#ConnBeginTx
type Conn struct {
e *executor.Engine
e *executor.Engine
tx *executor.Tx
}

func newConn(e *executor.Engine) *Conn {
Expand Down Expand Up @@ -83,6 +85,11 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) {
//
// Implemented for Conn interface
func (c *Conn) Close() error {
if c.tx != nil {
c.tx.Rollback()
c.tx = nil
}

return nil
}

Expand All @@ -92,27 +99,67 @@ func (c *Conn) Close() error {
//
// Implemented for Conn interface
func (c *Conn) Begin() (driver.Tx, error) {
return executor.NewTx(context.Background(), c.e, sql.TxOptions{})
tx, err := executor.NewTx(context.Background(), c.e, sql.TxOptions{})
if err != nil {
return nil, err
}
c.tx = tx
return c, nil
}

// BeginTx starts and returns a new transaction.
//
// Implemented for ConnBeginTx interface
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
o := sql.TxOptions{
Isolation: sql.IsolationLevel(opts.Isolation),
ReadOnly: opts.ReadOnly,
}
return executor.NewTx(ctx, c.e, o)
tx, err := executor.NewTx(ctx, c.e, o)
if err != nil {
return nil, err
}
c.tx = tx
return c, nil
}

func (c *Conn) Rollback() error {
if c.tx == nil {
return nil
}
err := c.tx.Rollback()
c.tx = nil
return err
}

func (c *Conn) Commit() error {
if c.tx == nil {
return nil
}
err := c.tx.Commit()
c.tx = nil
return err
}

// QueryContext is the sql package prefered way to run QUERY.
//
// Implemented for QueryerContext interface
func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
var err error
autocommit := false

tx, err := c.e.Begin()
if err != nil {
return nil, err
log.Info("Conn.QueryContext: %s", query)

tx := c.tx

if tx == nil {
autocommit = true
tx, err = c.e.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
}
defer tx.Rollback()

a := make([]executor.NamedValue, len(args))
for i, arg := range args {
Expand All @@ -126,9 +173,11 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam
return nil, err
}

err = tx.Commit()
if err != nil {
return nil, err
if autocommit {
err = tx.Commit()
if err != nil {
return nil, err
}
}

return newRows(cols, tuples), nil
Expand All @@ -138,9 +187,19 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam
//
// Implemented for ExecerContext interface
func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
tx, err := c.e.Begin()
if err != nil {
return nil, err
var err error
autocommit := false
log.Info("Conn.ExecContext: %s", query)

tx := c.tx

if tx == nil {
autocommit = true
tx, err = c.e.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
}

a := make([]executor.NamedValue, len(args))
Expand All @@ -156,9 +215,11 @@ func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.Name
return r, r.err
}

err = tx.Commit()
if err != nil {
return r, r.err
if autocommit {
err = tx.Commit()
if err != nil {
return r, r.err
}
}

return r, r.err
Expand Down
86 changes: 85 additions & 1 deletion driver/tx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package ramsql

import (
"database/sql"
"sync"
"testing"

"github.com/proullon/ramsql/engine/log"
)

func TestTransaction(t *testing.T) {
Expand Down Expand Up @@ -30,6 +33,19 @@ func TestTransaction(t *testing.T) {
}
}

db.SetMaxOpenConns(10)
var wg sync.WaitGroup

for i := 0; i < 15; i++ {
wg.Add(1)
go execTestTransactionQuery(t, db, &wg)
}

wg.Wait()
}

func execTestTransactionQuery(t *testing.T, db *sql.DB, wg *sync.WaitGroup) {

tx, err := db.Begin()
if err != nil {
t.Fatalf("Cannot create tx: %s", err)
Expand All @@ -50,7 +66,75 @@ func TestTransaction(t *testing.T) {
t.Fatalf("cannot commit tx: %s", err)
}

// Select count
wg.Done()
}

func TestTransactionRollback(t *testing.T) {
log.SetLevel(log.InfoLevel)
defer log.SetLevel(log.ErrorLevel)

db, err := sql.Open("ramsql", "TestTransactionRollback")
if err != nil {
t.Fatalf("sql.Open : Error : %s\n", err)
}
defer db.Close()

init := []string{
`CREATE TABLE account (id INT, email TEXT)`,
`INSERT INTO account (id, email) VALUES (1, '[email protected]')`,
`INSERT INTO account (id, email) VALUES (2, '[email protected]')`,
`CREATE TABLE champion (user_id INT, name TEXT)`,
`INSERT INTO champion (user_id, name) VALUES (1, 'zed')`,
`INSERT INTO champion (user_id, name) VALUES (2, 'lulu')`,
`INSERT INTO champion (user_id, name) VALUES (1, 'thresh')`,
`INSERT INTO champion (user_id, name) VALUES (1, 'lux')`,
}
for _, q := range init {
_, err = db.Exec(q)
if err != nil {
t.Fatalf("sql.Exec: Error: %s\n", err)
}
}

var count int
err = db.QueryRow("SELECT COUNT(user_id) FROM champion WHERE user_id = 1").Scan(&count)
if err != nil {
t.Fatalf("cannot query row in tx: %s\n", err)
}
if count != 3 {
t.Fatalf("expected COUNT(user_id)=3 row, got %d", count)
}

tx, err := db.Begin()
if err != nil {
t.Fatalf("cannot begin transaction: %s", err)
}

_, err = tx.Exec(`INSERT INTO champion (user_id, name) VALUES (1, 'new-champ')`)
if err != nil {
t.Fatalf("cannot insert within transaction: %s", err)
}

err = tx.QueryRow("SELECT COUNT(*) FROM champion WHERE user_id = 1").Scan(&count)
if err != nil {
t.Fatalf("cannot query row in tx: %s\n", err)
}
if count != 4 {
t.Fatalf("expected COUNT(user_id)=4 row within tx, got %d", count)
}

err = tx.Rollback()
if err != nil {
t.Fatalf("cannot rollback transaction: %s", err)
}

err = db.QueryRow("SELECT COUNT(user_id) FROM champion WHERE user_id = 1").Scan(&count)
if err != nil {
t.Fatalf("cannot query row in tx: %s\n", err)
}
if count != 3 {
t.Fatalf("expected COUNT(user_id)=3 row, got %d", count)
}
}

func TestCheckAttributes(t *testing.T) {
Expand Down
6 changes: 6 additions & 0 deletions engine/executor/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,13 @@ func NewTx(ctx context.Context, e *Engine, opts sql.TxOptions) (*Tx, error) {
parser.DropToken: dropExecutor,
parser.GrantToken: grantExecutor,
}

log.Info("Begin(%p)", t.tx)
return t, nil
}

func (t *Tx) QueryContext(ctx context.Context, query string, args []NamedValue) ([]string, []*agnostic.Tuple, error) {
log.Info("QueryContext(%p, %s)", t.tx, query)

instructions, err := parser.ParseInstruction(query)
if err != nil {
Expand Down Expand Up @@ -88,17 +91,20 @@ func (t *Tx) QueryContext(ctx context.Context, query string, args []NamedValue)

// Commit the transaction on server
func (t *Tx) Commit() error {
log.Info("Commit(%p)", t.tx)
_, err := t.tx.Commit()
return err
}

// Rollback all changes
func (t *Tx) Rollback() error {
log.Info("Rollback(%p)", t.tx)
t.tx.Rollback()
return nil
}

func (t *Tx) ExecContext(ctx context.Context, query string, args []NamedValue) (int64, int64, error) {
log.Info("ExecContext(%p, %s)", t.tx, query)

instructions, err := parser.ParseInstruction(query)
if err != nil {
Expand Down