From 96c0050043711de9c9f9328cbd9b8a041e9aff0a Mon Sep 17 00:00:00 2001 From: John Weldon <johnweldon4@gmail.com> Date: Sat, 23 Sep 2017 07:15:39 -0700 Subject: [PATCH] Merge pull request #134 from judwhite/feature/fix-race fix race conditions in conn.go --- .travis.yml | 6 ++++-- Makefile | 2 +- conn.go | 50 +++++++++++++++++++++---------------------------- conn_test.go | 4 ++-- debug.go | 2 +- dn.go | 4 ++-- error_test.go | 4 ++-- example_test.go | 10 +++++----- ldap.go | 2 +- passwdmodify.go | 8 ++++---- search_test.go | 6 +++--- 11 files changed, 46 insertions(+), 52 deletions(-) diff --git a/.travis.yml b/.travis.yml index e32a2aa7..9782c9ba 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,8 @@ language: go env: global: - - VET_VERSIONS="1.6 1.7 tip" - - LINT_VERSIONS="1.6 1.7 tip" + - VET_VERSIONS="1.6 1.7 1.8 1.9 tip" + - LINT_VERSIONS="1.6 1.7 1.8 1.9 tip" go: - 1.2 - 1.3 @@ -10,6 +10,8 @@ go: - 1.5 - 1.6 - 1.7 + - 1.8 + - 1.9 - tip matrix: fast_finish: true diff --git a/Makefile b/Makefile index f7899f59..a9d351c7 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ IS_OLD_GO := $(shell test $(GO_VERSION) -le 2 && echo true) ifeq ($(IS_OLD_GO),true) RACE_FLAG := else - RACE_FLAG := -race + RACE_FLAG := -race -cpu 1,2,4 endif default: fmt vet lint build quicktest diff --git a/conn.go b/conn.go index e701a9b6..eb28eb47 100644 --- a/conn.go +++ b/conn.go @@ -83,20 +83,18 @@ const ( type Conn struct { conn net.Conn isTLS bool - closeCount uint32 + closing uint32 closeErr atomicValue isStartingTLS bool Debug debugging - chanConfirm chan bool + chanConfirm chan struct{} messageContexts map[int64]*messageContext chanMessage chan *messagePacket chanMessageID chan int64 - wgSender sync.WaitGroup wgClose sync.WaitGroup - once sync.Once outstandingRequests uint messageMutex sync.Mutex - requestTimeout time.Duration + requestTimeout int64 } var _ Client = &Conn{} @@ -143,7 +141,7 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) { func NewConn(conn net.Conn, isTLS bool) *Conn { return &Conn{ conn: conn, - chanConfirm: make(chan bool), + chanConfirm: make(chan struct{}), chanMessageID: make(chan int64), chanMessage: make(chan *messagePacket, 10), messageContexts: map[int64]*messageContext{}, @@ -161,20 +159,20 @@ func (l *Conn) Start() { // isClosing returns whether or not we're currently closing. func (l *Conn) isClosing() bool { - return atomic.LoadUint32(&l.closeCount) > 0 + return atomic.LoadUint32(&l.closing) == 1 } // setClosing sets the closing value to true -func (l *Conn) setClosing() { - atomic.AddUint32(&l.closeCount, 1) +func (l *Conn) setClosing() bool { + return atomic.CompareAndSwapUint32(&l.closing, 0, 1) } // Close closes the connection. func (l *Conn) Close() { - l.once.Do(func() { - l.setClosing() - l.wgSender.Wait() + l.messageMutex.Lock() + defer l.messageMutex.Unlock() + if l.setClosing() { l.Debug.Printf("Sending quit message and waiting for confirmation") l.chanMessage <- &messagePacket{Op: MessageQuit} <-l.chanConfirm @@ -182,27 +180,25 @@ func (l *Conn) Close() { l.Debug.Printf("Closing network connection") if err := l.conn.Close(); err != nil { - log.Print(err) + log.Println(err) } l.wgClose.Done() - }) + } l.wgClose.Wait() } // SetTimeout sets the time after a request is sent that a MessageTimeout triggers func (l *Conn) SetTimeout(timeout time.Duration) { if timeout > 0 { - l.requestTimeout = timeout + atomic.StoreInt64(&l.requestTimeout, int64(timeout)) } } // Returns the next available messageID func (l *Conn) nextMessageID() int64 { - if l.chanMessageID != nil { - if messageID, ok := <-l.chanMessageID; ok { - return messageID - } + if messageID, ok := <-l.chanMessageID; ok { + return messageID } return 0 } @@ -327,12 +323,12 @@ func (l *Conn) finishMessage(msgCtx *messageContext) { } func (l *Conn) sendProcessMessage(message *messagePacket) bool { + l.messageMutex.Lock() + defer l.messageMutex.Unlock() if l.isClosing() { return false } - l.wgSender.Add(1) l.chanMessage <- message - l.wgSender.Done() return true } @@ -352,7 +348,6 @@ func (l *Conn) processMessages() { delete(l.messageContexts, messageID) } close(l.chanMessageID) - l.chanConfirm <- true close(l.chanConfirm) }() @@ -361,11 +356,7 @@ func (l *Conn) processMessages() { select { case l.chanMessageID <- messageID: messageID++ - case message, ok := <-l.chanMessage: - if !ok { - l.Debug.Printf("Shutting down - message channel is closed") - return - } + case message := <-l.chanMessage: switch message.Op { case MessageQuit: l.Debug.Printf("Shutting down - quit message received") @@ -388,14 +379,15 @@ func (l *Conn) processMessages() { l.messageContexts[message.MessageID] = message.Context // Add timeout if defined - if l.requestTimeout > 0 { + requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout)) + if requestTimeout > 0 { go func() { defer func() { if err := recover(); err != nil { log.Printf("ldap: recovered panic in RequestTimeout: %v", err) } }() - time.Sleep(l.requestTimeout) + time.Sleep(requestTimeout) timeoutMessage := &messagePacket{ Op: MessageTimeout, MessageID: message.MessageID, diff --git a/conn_test.go b/conn_test.go index 30554d23..488754d1 100644 --- a/conn_test.go +++ b/conn_test.go @@ -188,7 +188,7 @@ func runWithTimeout(t *testing.T, timeout time.Duration, f func()) { } } -// packetTranslatorConn is a helful type which can be used with various tests +// packetTranslatorConn is a helpful type which can be used with various tests // in this package. It implements the net.Conn interface to be used as an // underlying connection for a *ldap.Conn. Most methods are no-ops but the // Read() and Write() methods are able to translate ber-encoded packets for @@ -241,7 +241,7 @@ func (c *packetTranslatorConn) Read(b []byte) (n int, err error) { } // SendResponse writes the given response packet to the response buffer for -// this conection, signalling any goroutine waiting to read a response. +// this connection, signalling any goroutine waiting to read a response. func (c *packetTranslatorConn) SendResponse(packet *ber.Packet) error { c.lock.Lock() defer c.lock.Unlock() diff --git a/debug.go b/debug.go index b8a7ecbf..7279fc25 100644 --- a/debug.go +++ b/debug.go @@ -6,7 +6,7 @@ import ( "gopkg.in/asn1-ber.v1" ) -// debbuging type +// debugging type // - has a Printf method to write the debug output type debugging bool diff --git a/dn.go b/dn.go index 857b2ca7..34e9023a 100644 --- a/dn.go +++ b/dn.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // -// File contains DN parsing functionallity +// File contains DN parsing functionality // // https://tools.ietf.org/html/rfc4514 // @@ -52,7 +52,7 @@ import ( "fmt" "strings" - ber "gopkg.in/asn1-ber.v1" + "gopkg.in/asn1-ber.v1" ) // AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514 diff --git a/error_test.go b/error_test.go index c010ebe3..e456431b 100644 --- a/error_test.go +++ b/error_test.go @@ -49,7 +49,7 @@ func TestConnReadErr(t *testing.T) { // Send the signal after a short amount of time. time.AfterFunc(10*time.Millisecond, func() { conn.signals <- expectedError }) - // This should block until the underlyiny conn gets the error signal + // This should block until the underlying conn gets the error signal // which should bubble up through the reader() goroutine, close the // connection, and _, err := ldapConn.Search(searchReq) @@ -58,7 +58,7 @@ func TestConnReadErr(t *testing.T) { } } -// signalErrConn is a helful type used with TestConnReadErr. It implements the +// signalErrConn is a helpful type used with TestConnReadErr. It implements the // net.Conn interface to be used as a connection for the test. Most methods are // no-ops but the Read() method blocks until it receives a signal which it // returns as an error. diff --git a/example_test.go b/example_test.go index 821189bd..650af0a4 100644 --- a/example_test.go +++ b/example_test.go @@ -9,7 +9,7 @@ import ( ) // ExampleConn_Bind demonstrates how to bind a connection to an ldap user -// allowing access to restricted attrabutes that user has access to +// allowing access to restricted attributes that user has access to func ExampleConn_Bind() { l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389)) if err != nil { @@ -63,10 +63,10 @@ func ExampleConn_StartTLS() { log.Fatal(err) } - // Opertations via l are now encrypted + // Operations via l are now encrypted } -// ExampleConn_Compare demonstrates how to comapre an attribute with a value +// ExampleConn_Compare demonstrates how to compare an attribute with a value func ExampleConn_Compare() { l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389)) if err != nil { @@ -215,7 +215,7 @@ func Example_userAuthentication() { log.Fatal(err) } - // Rebind as the read only user for any futher queries + // Rebind as the read only user for any further queries err = l.Bind(bindusername, bindpassword) if err != nil { log.Fatal(err) @@ -240,7 +240,7 @@ func Example_beherappolicy() { if ppolicyControl != nil { ppolicy = ppolicyControl.(*ldap.ControlBeheraPasswordPolicy) } else { - log.Printf("ppolicyControl response not avaliable.\n") + log.Printf("ppolicyControl response not available.\n") } if err != nil { errStr := "ERROR: Cannot bind: " + err.Error() diff --git a/ldap.go b/ldap.go index d27e639d..49692475 100644 --- a/ldap.go +++ b/ldap.go @@ -9,7 +9,7 @@ import ( "io/ioutil" "os" - ber "gopkg.in/asn1-ber.v1" + "gopkg.in/asn1-ber.v1" ) // LDAP Application Codes diff --git a/passwdmodify.go b/passwdmodify.go index 26110ccf..7d8246fd 100644 --- a/passwdmodify.go +++ b/passwdmodify.go @@ -135,10 +135,10 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa extendedResponse := packet.Children[1] for _, child := range extendedResponse.Children { if child.Tag == 11 { - passwordModifyReponseValue := ber.DecodePacket(child.Data.Bytes()) - if len(passwordModifyReponseValue.Children) == 1 { - if passwordModifyReponseValue.Children[0].Tag == 0 { - result.GeneratedPassword = ber.DecodeString(passwordModifyReponseValue.Children[0].Data.Bytes()) + passwordModifyResponseValue := ber.DecodePacket(child.Data.Bytes()) + if len(passwordModifyResponseValue.Children) == 1 { + if passwordModifyResponseValue.Children[0].Tag == 0 { + result.GeneratedPassword = ber.DecodeString(passwordModifyResponseValue.Children[0].Data.Bytes()) } } } diff --git a/search_test.go b/search_test.go index efb8147d..5f77b22e 100644 --- a/search_test.go +++ b/search_test.go @@ -15,7 +15,7 @@ func TestNewEntry(t *testing.T) { "delta": {"value"}, "epsilon": {"value"}, } - exectedEntry := NewEntry(dn, attributes) + executedEntry := NewEntry(dn, attributes) iteration := 0 for { @@ -23,8 +23,8 @@ func TestNewEntry(t *testing.T) { break } testEntry := NewEntry(dn, attributes) - if !reflect.DeepEqual(exectedEntry, testEntry) { - t.Fatalf("consequent calls to NewEntry did not yield the same result:\n\texpected:\n\t%s\n\tgot:\n\t%s\n", exectedEntry, testEntry) + if !reflect.DeepEqual(executedEntry, testEntry) { + t.Fatalf("subsequent calls to NewEntry did not yield the same result:\n\texpected:\n\t%s\n\tgot:\n\t%s\n", executedEntry, testEntry) } iteration = iteration + 1 }