From 96c0050043711de9c9f9328cbd9b8a041e9aff0a Mon Sep 17 00:00:00 2001
From: John Weldon <johnweldon4@gmail.com>
Date: Sat, 23 Sep 2017 07:15:39 -0700
Subject: [PATCH] Merge pull request #134 from judwhite/feature/fix-race

fix race conditions in conn.go
---
 .travis.yml     |  6 ++++--
 Makefile        |  2 +-
 conn.go         | 50 +++++++++++++++++++++----------------------------
 conn_test.go    |  4 ++--
 debug.go        |  2 +-
 dn.go           |  4 ++--
 error_test.go   |  4 ++--
 example_test.go | 10 +++++-----
 ldap.go         |  2 +-
 passwdmodify.go |  8 ++++----
 search_test.go  |  6 +++---
 11 files changed, 46 insertions(+), 52 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index e32a2aa7..9782c9ba 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,8 +1,8 @@
 language: go
 env:
     global:
-        - VET_VERSIONS="1.6 1.7 tip"
-        - LINT_VERSIONS="1.6 1.7 tip"
+        - VET_VERSIONS="1.6 1.7 1.8 1.9 tip"
+        - LINT_VERSIONS="1.6 1.7 1.8 1.9 tip"
 go:
     - 1.2
     - 1.3
@@ -10,6 +10,8 @@ go:
     - 1.5
     - 1.6
     - 1.7
+    - 1.8
+    - 1.9
     - tip
 matrix:
     fast_finish: true
diff --git a/Makefile b/Makefile
index f7899f59..a9d351c7 100644
--- a/Makefile
+++ b/Makefile
@@ -7,7 +7,7 @@ IS_OLD_GO := $(shell test $(GO_VERSION) -le 2 && echo true)
 ifeq ($(IS_OLD_GO),true)
 	RACE_FLAG :=
 else
-	RACE_FLAG := -race
+	RACE_FLAG := -race -cpu 1,2,4
 endif
 
 default: fmt vet lint build quicktest
diff --git a/conn.go b/conn.go
index e701a9b6..eb28eb47 100644
--- a/conn.go
+++ b/conn.go
@@ -83,20 +83,18 @@ const (
 type Conn struct {
 	conn                net.Conn
 	isTLS               bool
-	closeCount          uint32
+	closing             uint32
 	closeErr            atomicValue
 	isStartingTLS       bool
 	Debug               debugging
-	chanConfirm         chan bool
+	chanConfirm         chan struct{}
 	messageContexts     map[int64]*messageContext
 	chanMessage         chan *messagePacket
 	chanMessageID       chan int64
-	wgSender            sync.WaitGroup
 	wgClose             sync.WaitGroup
-	once                sync.Once
 	outstandingRequests uint
 	messageMutex        sync.Mutex
-	requestTimeout      time.Duration
+	requestTimeout      int64
 }
 
 var _ Client = &Conn{}
@@ -143,7 +141,7 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
 func NewConn(conn net.Conn, isTLS bool) *Conn {
 	return &Conn{
 		conn:            conn,
-		chanConfirm:     make(chan bool),
+		chanConfirm:     make(chan struct{}),
 		chanMessageID:   make(chan int64),
 		chanMessage:     make(chan *messagePacket, 10),
 		messageContexts: map[int64]*messageContext{},
@@ -161,20 +159,20 @@ func (l *Conn) Start() {
 
 // isClosing returns whether or not we're currently closing.
 func (l *Conn) isClosing() bool {
-	return atomic.LoadUint32(&l.closeCount) > 0
+	return atomic.LoadUint32(&l.closing) == 1
 }
 
 // setClosing sets the closing value to true
-func (l *Conn) setClosing() {
-	atomic.AddUint32(&l.closeCount, 1)
+func (l *Conn) setClosing() bool {
+	return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
 }
 
 // Close closes the connection.
 func (l *Conn) Close() {
-	l.once.Do(func() {
-		l.setClosing()
-		l.wgSender.Wait()
+	l.messageMutex.Lock()
+	defer l.messageMutex.Unlock()
 
+	if l.setClosing() {
 		l.Debug.Printf("Sending quit message and waiting for confirmation")
 		l.chanMessage <- &messagePacket{Op: MessageQuit}
 		<-l.chanConfirm
@@ -182,27 +180,25 @@ func (l *Conn) Close() {
 
 		l.Debug.Printf("Closing network connection")
 		if err := l.conn.Close(); err != nil {
-			log.Print(err)
+			log.Println(err)
 		}
 
 		l.wgClose.Done()
-	})
+	}
 	l.wgClose.Wait()
 }
 
 // SetTimeout sets the time after a request is sent that a MessageTimeout triggers
 func (l *Conn) SetTimeout(timeout time.Duration) {
 	if timeout > 0 {
-		l.requestTimeout = timeout
+		atomic.StoreInt64(&l.requestTimeout, int64(timeout))
 	}
 }
 
 // Returns the next available messageID
 func (l *Conn) nextMessageID() int64 {
-	if l.chanMessageID != nil {
-		if messageID, ok := <-l.chanMessageID; ok {
-			return messageID
-		}
+	if messageID, ok := <-l.chanMessageID; ok {
+		return messageID
 	}
 	return 0
 }
@@ -327,12 +323,12 @@ func (l *Conn) finishMessage(msgCtx *messageContext) {
 }
 
 func (l *Conn) sendProcessMessage(message *messagePacket) bool {
+	l.messageMutex.Lock()
+	defer l.messageMutex.Unlock()
 	if l.isClosing() {
 		return false
 	}
-	l.wgSender.Add(1)
 	l.chanMessage <- message
-	l.wgSender.Done()
 	return true
 }
 
@@ -352,7 +348,6 @@ func (l *Conn) processMessages() {
 			delete(l.messageContexts, messageID)
 		}
 		close(l.chanMessageID)
-		l.chanConfirm <- true
 		close(l.chanConfirm)
 	}()
 
@@ -361,11 +356,7 @@ func (l *Conn) processMessages() {
 		select {
 		case l.chanMessageID <- messageID:
 			messageID++
-		case message, ok := <-l.chanMessage:
-			if !ok {
-				l.Debug.Printf("Shutting down - message channel is closed")
-				return
-			}
+		case message := <-l.chanMessage:
 			switch message.Op {
 			case MessageQuit:
 				l.Debug.Printf("Shutting down - quit message received")
@@ -388,14 +379,15 @@ func (l *Conn) processMessages() {
 				l.messageContexts[message.MessageID] = message.Context
 
 				// Add timeout if defined
-				if l.requestTimeout > 0 {
+				requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
+				if requestTimeout > 0 {
 					go func() {
 						defer func() {
 							if err := recover(); err != nil {
 								log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
 							}
 						}()
-						time.Sleep(l.requestTimeout)
+						time.Sleep(requestTimeout)
 						timeoutMessage := &messagePacket{
 							Op:        MessageTimeout,
 							MessageID: message.MessageID,
diff --git a/conn_test.go b/conn_test.go
index 30554d23..488754d1 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -188,7 +188,7 @@ func runWithTimeout(t *testing.T, timeout time.Duration, f func()) {
 	}
 }
 
-// packetTranslatorConn is a helful type which can be used with various tests
+// packetTranslatorConn is a helpful type which can be used with various tests
 // in this package. It implements the net.Conn interface to be used as an
 // underlying connection for a *ldap.Conn. Most methods are no-ops but the
 // Read() and Write() methods are able to translate ber-encoded packets for
@@ -241,7 +241,7 @@ func (c *packetTranslatorConn) Read(b []byte) (n int, err error) {
 }
 
 // SendResponse writes the given response packet to the response buffer for
-// this conection, signalling any goroutine waiting to read a response.
+// this connection, signalling any goroutine waiting to read a response.
 func (c *packetTranslatorConn) SendResponse(packet *ber.Packet) error {
 	c.lock.Lock()
 	defer c.lock.Unlock()
diff --git a/debug.go b/debug.go
index b8a7ecbf..7279fc25 100644
--- a/debug.go
+++ b/debug.go
@@ -6,7 +6,7 @@ import (
 	"gopkg.in/asn1-ber.v1"
 )
 
-// debbuging type
+// debugging type
 //     - has a Printf method to write the debug output
 type debugging bool
 
diff --git a/dn.go b/dn.go
index 857b2ca7..34e9023a 100644
--- a/dn.go
+++ b/dn.go
@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 //
-// File contains DN parsing functionallity
+// File contains DN parsing functionality
 //
 // https://tools.ietf.org/html/rfc4514
 //
@@ -52,7 +52,7 @@ import (
 	"fmt"
 	"strings"
 
-	ber "gopkg.in/asn1-ber.v1"
+	"gopkg.in/asn1-ber.v1"
 )
 
 // AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514
diff --git a/error_test.go b/error_test.go
index c010ebe3..e456431b 100644
--- a/error_test.go
+++ b/error_test.go
@@ -49,7 +49,7 @@ func TestConnReadErr(t *testing.T) {
 	// Send the signal after a short amount of time.
 	time.AfterFunc(10*time.Millisecond, func() { conn.signals <- expectedError })
 
-	// This should block until the underlyiny conn gets the error signal
+	// This should block until the underlying conn gets the error signal
 	// which should bubble up through the reader() goroutine, close the
 	// connection, and
 	_, err := ldapConn.Search(searchReq)
@@ -58,7 +58,7 @@ func TestConnReadErr(t *testing.T) {
 	}
 }
 
-// signalErrConn is a helful type used with TestConnReadErr. It implements the
+// signalErrConn is a helpful type used with TestConnReadErr. It implements the
 // net.Conn interface to be used as a connection for the test. Most methods are
 // no-ops but the Read() method blocks until it receives a signal which it
 // returns as an error.
diff --git a/example_test.go b/example_test.go
index 821189bd..650af0a4 100644
--- a/example_test.go
+++ b/example_test.go
@@ -9,7 +9,7 @@ import (
 )
 
 // ExampleConn_Bind demonstrates how to bind a connection to an ldap user
-// allowing access to restricted attrabutes that user has access to
+// allowing access to restricted attributes that user has access to
 func ExampleConn_Bind() {
 	l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389))
 	if err != nil {
@@ -63,10 +63,10 @@ func ExampleConn_StartTLS() {
 		log.Fatal(err)
 	}
 
-	// Opertations via l are now encrypted
+	// Operations via l are now encrypted
 }
 
-// ExampleConn_Compare demonstrates how to comapre an attribute with a value
+// ExampleConn_Compare demonstrates how to compare an attribute with a value
 func ExampleConn_Compare() {
 	l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389))
 	if err != nil {
@@ -215,7 +215,7 @@ func Example_userAuthentication() {
 		log.Fatal(err)
 	}
 
-	// Rebind as the read only user for any futher queries
+	// Rebind as the read only user for any further queries
 	err = l.Bind(bindusername, bindpassword)
 	if err != nil {
 		log.Fatal(err)
@@ -240,7 +240,7 @@ func Example_beherappolicy() {
 	if ppolicyControl != nil {
 		ppolicy = ppolicyControl.(*ldap.ControlBeheraPasswordPolicy)
 	} else {
-		log.Printf("ppolicyControl response not avaliable.\n")
+		log.Printf("ppolicyControl response not available.\n")
 	}
 	if err != nil {
 		errStr := "ERROR: Cannot bind: " + err.Error()
diff --git a/ldap.go b/ldap.go
index d27e639d..49692475 100644
--- a/ldap.go
+++ b/ldap.go
@@ -9,7 +9,7 @@ import (
 	"io/ioutil"
 	"os"
 
-	ber "gopkg.in/asn1-ber.v1"
+	"gopkg.in/asn1-ber.v1"
 )
 
 // LDAP Application Codes
diff --git a/passwdmodify.go b/passwdmodify.go
index 26110ccf..7d8246fd 100644
--- a/passwdmodify.go
+++ b/passwdmodify.go
@@ -135,10 +135,10 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa
 	extendedResponse := packet.Children[1]
 	for _, child := range extendedResponse.Children {
 		if child.Tag == 11 {
-			passwordModifyReponseValue := ber.DecodePacket(child.Data.Bytes())
-			if len(passwordModifyReponseValue.Children) == 1 {
-				if passwordModifyReponseValue.Children[0].Tag == 0 {
-					result.GeneratedPassword = ber.DecodeString(passwordModifyReponseValue.Children[0].Data.Bytes())
+			passwordModifyResponseValue := ber.DecodePacket(child.Data.Bytes())
+			if len(passwordModifyResponseValue.Children) == 1 {
+				if passwordModifyResponseValue.Children[0].Tag == 0 {
+					result.GeneratedPassword = ber.DecodeString(passwordModifyResponseValue.Children[0].Data.Bytes())
 				}
 			}
 		}
diff --git a/search_test.go b/search_test.go
index efb8147d..5f77b22e 100644
--- a/search_test.go
+++ b/search_test.go
@@ -15,7 +15,7 @@ func TestNewEntry(t *testing.T) {
 		"delta":   {"value"},
 		"epsilon": {"value"},
 	}
-	exectedEntry := NewEntry(dn, attributes)
+	executedEntry := NewEntry(dn, attributes)
 
 	iteration := 0
 	for {
@@ -23,8 +23,8 @@ func TestNewEntry(t *testing.T) {
 			break
 		}
 		testEntry := NewEntry(dn, attributes)
-		if !reflect.DeepEqual(exectedEntry, testEntry) {
-			t.Fatalf("consequent calls to NewEntry did not yield the same result:\n\texpected:\n\t%s\n\tgot:\n\t%s\n", exectedEntry, testEntry)
+		if !reflect.DeepEqual(executedEntry, testEntry) {
+			t.Fatalf("subsequent calls to NewEntry did not yield the same result:\n\texpected:\n\t%s\n\tgot:\n\t%s\n", executedEntry, testEntry)
 		}
 		iteration = iteration + 1
 	}