Skip to content

Commit

Permalink
Merge pull request #134 from judwhite/feature/fix-race
Browse files Browse the repository at this point in the history
fix race conditions in conn.go
  • Loading branch information
johnweldon authored Sep 23, 2017
2 parents 95ede12 + 3ca927c commit 3de5b9b
Show file tree
Hide file tree
Showing 12 changed files with 47 additions and 53 deletions.
6 changes: 4 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
language: go
env:
global:
- VET_VERSIONS="1.6 1.7 tip"
- LINT_VERSIONS="1.6 1.7 tip"
- VET_VERSIONS="1.6 1.7 1.8 1.9 tip"
- LINT_VERSIONS="1.6 1.7 1.8 1.9 tip"
go:
- 1.2
- 1.3
- 1.4
- 1.5
- 1.6
- 1.7
- 1.8
- 1.9
- tip
matrix:
fast_finish: true
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ IS_OLD_GO := $(shell test $(GO_VERSION) -le 2 && echo true)
ifeq ($(IS_OLD_GO),true)
RACE_FLAG :=
else
RACE_FLAG := -race
RACE_FLAG := -race -cpu 1,2,4
endif

default: fmt vet lint build quicktest
Expand Down
2 changes: 1 addition & 1 deletion bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package ldap
import (
"errors"

ber "gopkg.in/asn1-ber.v1"
"gopkg.in/asn1-ber.v1"
)

// SimpleBindRequest represents a username/password bind operation
Expand Down
50 changes: 21 additions & 29 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,18 @@ const (
type Conn struct {
conn net.Conn
isTLS bool
closeCount uint32
closing uint32
closeErr atomicValue
isStartingTLS bool
Debug debugging
chanConfirm chan bool
chanConfirm chan struct{}
messageContexts map[int64]*messageContext
chanMessage chan *messagePacket
chanMessageID chan int64
wgSender sync.WaitGroup
wgClose sync.WaitGroup
once sync.Once
outstandingRequests uint
messageMutex sync.Mutex
requestTimeout time.Duration
requestTimeout int64
}

var _ Client = &Conn{}
Expand Down Expand Up @@ -143,7 +141,7 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
func NewConn(conn net.Conn, isTLS bool) *Conn {
return &Conn{
conn: conn,
chanConfirm: make(chan bool),
chanConfirm: make(chan struct{}),
chanMessageID: make(chan int64),
chanMessage: make(chan *messagePacket, 10),
messageContexts: map[int64]*messageContext{},
Expand All @@ -161,48 +159,46 @@ func (l *Conn) Start() {

// isClosing returns whether or not we're currently closing.
func (l *Conn) isClosing() bool {
return atomic.LoadUint32(&l.closeCount) > 0
return atomic.LoadUint32(&l.closing) == 1
}

// setClosing sets the closing value to true
func (l *Conn) setClosing() {
atomic.AddUint32(&l.closeCount, 1)
func (l *Conn) setClosing() bool {
return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
}

// Close closes the connection.
func (l *Conn) Close() {
l.once.Do(func() {
l.setClosing()
l.wgSender.Wait()
l.messageMutex.Lock()
defer l.messageMutex.Unlock()

if l.setClosing() {
l.Debug.Printf("Sending quit message and waiting for confirmation")
l.chanMessage <- &messagePacket{Op: MessageQuit}
<-l.chanConfirm
close(l.chanMessage)

l.Debug.Printf("Closing network connection")
if err := l.conn.Close(); err != nil {
log.Print(err)
log.Println(err)
}

l.wgClose.Done()
})
}
l.wgClose.Wait()
}

// SetTimeout sets the time after a request is sent that a MessageTimeout triggers
func (l *Conn) SetTimeout(timeout time.Duration) {
if timeout > 0 {
l.requestTimeout = timeout
atomic.StoreInt64(&l.requestTimeout, int64(timeout))
}
}

// Returns the next available messageID
func (l *Conn) nextMessageID() int64 {
if l.chanMessageID != nil {
if messageID, ok := <-l.chanMessageID; ok {
return messageID
}
if messageID, ok := <-l.chanMessageID; ok {
return messageID
}
return 0
}
Expand Down Expand Up @@ -327,12 +323,12 @@ func (l *Conn) finishMessage(msgCtx *messageContext) {
}

func (l *Conn) sendProcessMessage(message *messagePacket) bool {
l.messageMutex.Lock()
defer l.messageMutex.Unlock()
if l.isClosing() {
return false
}
l.wgSender.Add(1)
l.chanMessage <- message
l.wgSender.Done()
return true
}

Expand All @@ -352,7 +348,6 @@ func (l *Conn) processMessages() {
delete(l.messageContexts, messageID)
}
close(l.chanMessageID)
l.chanConfirm <- true
close(l.chanConfirm)
}()

Expand All @@ -361,11 +356,7 @@ func (l *Conn) processMessages() {
select {
case l.chanMessageID <- messageID:
messageID++
case message, ok := <-l.chanMessage:
if !ok {
l.Debug.Printf("Shutting down - message channel is closed")
return
}
case message := <-l.chanMessage:
switch message.Op {
case MessageQuit:
l.Debug.Printf("Shutting down - quit message received")
Expand All @@ -388,14 +379,15 @@ func (l *Conn) processMessages() {
l.messageContexts[message.MessageID] = message.Context

// Add timeout if defined
if l.requestTimeout > 0 {
requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
if requestTimeout > 0 {
go func() {
defer func() {
if err := recover(); err != nil {
log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
}
}()
time.Sleep(l.requestTimeout)
time.Sleep(requestTimeout)
timeoutMessage := &messagePacket{
Op: MessageTimeout,
MessageID: message.MessageID,
Expand Down
4 changes: 2 additions & 2 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func runWithTimeout(t *testing.T, timeout time.Duration, f func()) {
}
}

// packetTranslatorConn is a helful type which can be used with various tests
// packetTranslatorConn is a helpful type which can be used with various tests
// in this package. It implements the net.Conn interface to be used as an
// underlying connection for a *ldap.Conn. Most methods are no-ops but the
// Read() and Write() methods are able to translate ber-encoded packets for
Expand Down Expand Up @@ -241,7 +241,7 @@ func (c *packetTranslatorConn) Read(b []byte) (n int, err error) {
}

// SendResponse writes the given response packet to the response buffer for
// this conection, signalling any goroutine waiting to read a response.
// this connection, signalling any goroutine waiting to read a response.
func (c *packetTranslatorConn) SendResponse(packet *ber.Packet) error {
c.lock.Lock()
defer c.lock.Unlock()
Expand Down
2 changes: 1 addition & 1 deletion debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"gopkg.in/asn1-ber.v1"
)

// debbuging type
// debugging type
// - has a Printf method to write the debug output
type debugging bool

Expand Down
4 changes: 2 additions & 2 deletions dn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// File contains DN parsing functionallity
// File contains DN parsing functionality
//
// https://tools.ietf.org/html/rfc4514
//
Expand Down Expand Up @@ -52,7 +52,7 @@ import (
"fmt"
"strings"

ber "gopkg.in/asn1-ber.v1"
"gopkg.in/asn1-ber.v1"
)

// AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514
Expand Down
4 changes: 2 additions & 2 deletions error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestConnReadErr(t *testing.T) {
// Send the signal after a short amount of time.
time.AfterFunc(10*time.Millisecond, func() { conn.signals <- expectedError })

// This should block until the underlyiny conn gets the error signal
// This should block until the underlying conn gets the error signal
// which should bubble up through the reader() goroutine, close the
// connection, and
_, err := ldapConn.Search(searchReq)
Expand All @@ -58,7 +58,7 @@ func TestConnReadErr(t *testing.T) {
}
}

// signalErrConn is a helful type used with TestConnReadErr. It implements the
// signalErrConn is a helpful type used with TestConnReadErr. It implements the
// net.Conn interface to be used as a connection for the test. Most methods are
// no-ops but the Read() method blocks until it receives a signal which it
// returns as an error.
Expand Down
10 changes: 5 additions & 5 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

// ExampleConn_Bind demonstrates how to bind a connection to an ldap user
// allowing access to restricted attrabutes that user has access to
// allowing access to restricted attributes that user has access to
func ExampleConn_Bind() {
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389))
if err != nil {
Expand Down Expand Up @@ -63,10 +63,10 @@ func ExampleConn_StartTLS() {
log.Fatal(err)
}

// Opertations via l are now encrypted
// Operations via l are now encrypted
}

// ExampleConn_Compare demonstrates how to comapre an attribute with a value
// ExampleConn_Compare demonstrates how to compare an attribute with a value
func ExampleConn_Compare() {
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389))
if err != nil {
Expand Down Expand Up @@ -215,7 +215,7 @@ func Example_userAuthentication() {
log.Fatal(err)
}

// Rebind as the read only user for any futher queries
// Rebind as the read only user for any further queries
err = l.Bind(bindusername, bindpassword)
if err != nil {
log.Fatal(err)
Expand All @@ -240,7 +240,7 @@ func Example_beherappolicy() {
if ppolicyControl != nil {
ppolicy = ppolicyControl.(*ldap.ControlBeheraPasswordPolicy)
} else {
log.Printf("ppolicyControl response not avaliable.\n")
log.Printf("ppolicyControl response not available.\n")
}
if err != nil {
errStr := "ERROR: Cannot bind: " + err.Error()
Expand Down
2 changes: 1 addition & 1 deletion ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"io/ioutil"
"os"

ber "gopkg.in/asn1-ber.v1"
"gopkg.in/asn1-ber.v1"
)

// LDAP Application Codes
Expand Down
8 changes: 4 additions & 4 deletions passwdmodify.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,10 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa
extendedResponse := packet.Children[1]
for _, child := range extendedResponse.Children {
if child.Tag == 11 {
passwordModifyReponseValue := ber.DecodePacket(child.Data.Bytes())
if len(passwordModifyReponseValue.Children) == 1 {
if passwordModifyReponseValue.Children[0].Tag == 0 {
result.GeneratedPassword = ber.DecodeString(passwordModifyReponseValue.Children[0].Data.Bytes())
passwordModifyResponseValue := ber.DecodePacket(child.Data.Bytes())
if len(passwordModifyResponseValue.Children) == 1 {
if passwordModifyResponseValue.Children[0].Tag == 0 {
result.GeneratedPassword = ber.DecodeString(passwordModifyResponseValue.Children[0].Data.Bytes())
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ func TestNewEntry(t *testing.T) {
"delta": {"value"},
"epsilon": {"value"},
}
exectedEntry := NewEntry(dn, attributes)
executedEntry := NewEntry(dn, attributes)

iteration := 0
for {
if iteration == 100 {
break
}
testEntry := NewEntry(dn, attributes)
if !reflect.DeepEqual(exectedEntry, testEntry) {
t.Fatalf("consequent calls to NewEntry did not yield the same result:\n\texpected:\n\t%s\n\tgot:\n\t%s\n", exectedEntry, testEntry)
if !reflect.DeepEqual(executedEntry, testEntry) {
t.Fatalf("subsequent calls to NewEntry did not yield the same result:\n\texpected:\n\t%s\n\tgot:\n\t%s\n", executedEntry, testEntry)
}
iteration = iteration + 1
}
Expand Down

0 comments on commit 3de5b9b

Please sign in to comment.