Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: enhance transaction functionality #1281

Merged
merged 5 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# Run `make lint` from the root path of this project to check code with golangci-lint.

run:
deadline: 6m
timeout: 5m

linters:
# Uncomment this line to run only the explicitly enabled linters
Expand Down
4 changes: 2 additions & 2 deletions pulsar/consumer_partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,8 @@ func (pc *partitionConsumer) internalAckWithTxn(req *ackWithTxnRequest) {
req.err = newError(ConsumerClosed, "Failed to ack by closing or closed consumer")
return
}
if req.Transaction.state != TxnOpen {
pc.log.WithField("state", req.Transaction.state).Error("Failed to ack by a non-open transaction.")
if req.Transaction.state.Load() != int32(TxnOpen) {
pc.log.WithField("state", req.Transaction.state.Load()).Error("Failed to ack by a non-open transaction.")
req.err = newError(InvalidStatus, "Failed to ack by a non-open transaction.")
return
}
Expand Down
4 changes: 2 additions & 2 deletions pulsar/producer_partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -1135,8 +1135,8 @@ func (p *partitionProducer) prepareTransaction(sr *sendRequest) error {
}

txn := (sr.msg.Transaction).(*transaction)
if txn.state != TxnOpen {
p.log.WithField("state", txn.state).Error("Failed to send message" +
if txn.state.Load() != int32(TxnOpen) {
p.log.WithField("state", txn.state.Load()).Error("Failed to send message" +
" by a non-open transaction.")
return joinErrors(ErrTransaction,
fmt.Errorf("failed to send message by a non-open transaction"))
Expand Down
112 changes: 64 additions & 48 deletions pulsar/transaction_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package pulsar

import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
Expand All @@ -33,9 +35,9 @@ type subscription struct {
}

type transaction struct {
sync.Mutex
mu sync.Mutex
txnID TxnID
state TxnState
state atomic.Int32
tcClient *transactionCoordinatorClient
registerPartitions map[string]bool
registerAckSubscriptions map[subscription]bool
Expand All @@ -54,96 +56,106 @@ type transaction struct {
// 1. When the transaction is committed or aborted, a bool will be read from opsFlow chan.
// 2. When the opsCount increments from 0 to 1, a bool will be read from the opsFlow channel.
opsFlow chan bool
opsCount int32
opsCount atomic.Int32
opTimeout time.Duration
log log.Logger
}

func newTransaction(id TxnID, tcClient *transactionCoordinatorClient, timeout time.Duration) *transaction {
transaction := &transaction{
txnID: id,
state: TxnOpen,
registerPartitions: make(map[string]bool),
registerAckSubscriptions: make(map[subscription]bool),
opsFlow: make(chan bool, 1),
opTimeout: 5 * time.Second,
opTimeout: tcClient.client.operationTimeout,
tcClient: tcClient,
}
//This means there are not pending requests with this transaction. The transaction can be committed or aborted.
transaction.state.Store(int32(TxnOpen))
// This means there are no pending requests with this transaction. The transaction can be committed or aborted.
transaction.opsFlow <- true
go func() {
//Set the state of the transaction to timeout after timeout
// Set the state of the transaction to timeout after timeout
<-time.After(timeout)
atomic.CompareAndSwapInt32((*int32)(&transaction.state), int32(TxnOpen), int32(TxnTimeout))
transaction.state.CompareAndSwap(int32(TxnOpen), int32(TxnTimeout))
}()
transaction.log = tcClient.log.SubLogger(log.Fields{})
return transaction
}

func (txn *transaction) GetState() TxnState {
return txn.state
return TxnState(txn.state.Load())
}

func (txn *transaction) Commit(_ context.Context) error {
if !(atomic.CompareAndSwapInt32((*int32)(&txn.state), int32(TxnOpen), int32(TxnCommitting)) ||
txn.state == TxnCommitting) {
return newError(InvalidStatus, "Expect transaction state is TxnOpen but "+txn.state.string())
func (txn *transaction) Commit(ctx context.Context) error {
if !(txn.state.CompareAndSwap(int32(TxnOpen), int32(TxnCommitting))) {
txnState := txn.state.Load()
return newError(InvalidStatus, txnStateErrorMessage(TxnOpen, TxnState(txnState)))
}

//Wait for all operations to complete
// Wait for all operations to complete
select {
case <-txn.opsFlow:
case <-ctx.Done():
txn.state.Store(int32(TxnOpen))
return ctx.Err()
case <-time.After(txn.opTimeout):
txn.state.Store(int32(TxnTimeout))
return newError(TimeoutError, "There are some operations that are not completed after the timeout.")
}
//Send commit transaction command to transaction coordinator
// Send commit transaction command to transaction coordinator
err := txn.tcClient.endTxn(&txn.txnID, pb.TxnAction_COMMIT)
if err == nil {
atomic.StoreInt32((*int32)(&txn.state), int32(TxnCommitted))
txn.state.Store(int32(TxnCommitted))
} else {
if e, ok := err.(*Error); ok && (e.Result() == TransactionNoFoundError || e.Result() == InvalidStatus) {
atomic.StoreInt32((*int32)(&txn.state), int32(TxnError))
var e *Error
if errors.As(err, &e) && (e.Result() == TransactionNoFoundError || e.Result() == InvalidStatus) {
txn.state.Store(int32(TxnError))
return err
}
txn.opsFlow <- true
}
return err
}

func (txn *transaction) Abort(_ context.Context) error {
if !(atomic.CompareAndSwapInt32((*int32)(&txn.state), int32(TxnOpen), int32(TxnAborting)) ||
txn.state == TxnAborting) {
return newError(InvalidStatus, "Expect transaction state is TxnOpen but "+txn.state.string())
func (txn *transaction) Abort(ctx context.Context) error {
if !(txn.state.CompareAndSwap(int32(TxnOpen), int32(TxnAborting))) {
txnState := txn.state.Load()
return newError(InvalidStatus, txnStateErrorMessage(TxnOpen, TxnState(txnState)))
}

//Wait for all operations to complete
// Wait for all operations to complete
select {
case <-txn.opsFlow:
case <-ctx.Done():
txn.state.Store(int32(TxnOpen))
return ctx.Err()
case <-time.After(txn.opTimeout):
txn.state.Store(int32(TxnTimeout))
return newError(TimeoutError, "There are some operations that are not completed after the timeout.")
}
//Send abort transaction command to transaction coordinator
// Send abort transaction command to transaction coordinator
err := txn.tcClient.endTxn(&txn.txnID, pb.TxnAction_ABORT)
if err == nil {
atomic.StoreInt32((*int32)(&txn.state), int32(TxnAborted))
txn.state.Store(int32(TxnAborted))
} else {
if e, ok := err.(*Error); ok && (e.Result() == TransactionNoFoundError || e.Result() == InvalidStatus) {
atomic.StoreInt32((*int32)(&txn.state), int32(TxnError))
} else {
txn.opsFlow <- true
var e *Error
if errors.As(err, &e) && (e.Result() == TransactionNoFoundError || e.Result() == InvalidStatus) {
txn.state.Store(int32(TxnError))
return err
}
txn.opsFlow <- true
}
return err
}

func (txn *transaction) registerSendOrAckOp() error {
if atomic.AddInt32(&txn.opsCount, 1) == 1 {
//There are new operations that not completed
if txn.opsCount.Add(1) == 1 {
// There are new operations that were not completed
select {
case <-txn.opsFlow:
return nil
case <-time.After(txn.opTimeout):
if _, err := txn.checkIfOpen(); err != nil {
if err := txn.verifyOpen(); err != nil {
return err
}
return newError(TimeoutError, "Failed to get the semaphore to register the send/ack operation")
Expand All @@ -154,23 +166,22 @@ func (txn *transaction) registerSendOrAckOp() error {

func (txn *transaction) endSendOrAckOp(err error) {
if err != nil {
atomic.StoreInt32((*int32)(&txn.state), int32(TxnError))
txn.state.Store(int32(TxnError))
}
if atomic.AddInt32(&txn.opsCount, -1) == 0 {
//This means there are not pending send/ack requests
if txn.opsCount.Add(-1) == 0 {
// This means there are no pending send/ack requests
txn.opsFlow <- true
}
}

func (txn *transaction) registerProducerTopic(topic string) error {
isOpen, err := txn.checkIfOpen()
if !isOpen {
if err := txn.verifyOpen(); err != nil {
return err
}
_, ok := txn.registerPartitions[topic]
if !ok {
txn.Lock()
defer txn.Unlock()
txn.mu.Lock()
defer txn.mu.Unlock()
if _, ok = txn.registerPartitions[topic]; !ok {
err := txn.tcClient.addPublishPartitionToTxn(&txn.txnID, []string{topic})
if err != nil {
Expand All @@ -183,8 +194,7 @@ func (txn *transaction) registerProducerTopic(topic string) error {
}

func (txn *transaction) registerAckTopic(topic string, subName string) error {
isOpen, err := txn.checkIfOpen()
if !isOpen {
if err := txn.verifyOpen(); err != nil {
return err
}
sub := subscription{
Expand All @@ -193,8 +203,8 @@ func (txn *transaction) registerAckTopic(topic string, subName string) error {
}
_, ok := txn.registerAckSubscriptions[sub]
if !ok {
txn.Lock()
defer txn.Unlock()
txn.mu.Lock()
defer txn.mu.Unlock()
if _, ok = txn.registerAckSubscriptions[sub]; !ok {
err := txn.tcClient.addSubscriptionToTxn(&txn.txnID, topic, subName)
if err != nil {
Expand All @@ -210,14 +220,15 @@ func (txn *transaction) GetTxnID() TxnID {
return txn.txnID
}

func (txn *transaction) checkIfOpen() (bool, error) {
if txn.state == TxnOpen {
return true, nil
func (txn *transaction) verifyOpen() error {
txnState := txn.state.Load()
if txnState != int32(TxnOpen) {
return newError(InvalidStatus, txnStateErrorMessage(TxnOpen, TxnState(txnState)))
}
return false, newError(InvalidStatus, "Expect transaction state is TxnOpen but "+txn.state.string())
return nil
}

func (state TxnState) string() string {
func (state TxnState) String() string {
switch state {
case TxnOpen:
return "TxnOpen"
Expand All @@ -237,3 +248,8 @@ func (state TxnState) string() string {
return "Unknown"
}
}

// txnStateErrorMessage builds the human-readable message used when a
// transaction is observed in the state actual while the state expected was
// required. Both states are rendered with the %s verb.
//
//nolint:unparam
func txnStateErrorMessage(expected, actual TxnState) string {
	const format = "Expected transaction state: %s, actual: %s"
	return fmt.Sprintf(format, expected, actual)
}
Loading
Loading