diff --git a/conn.go b/conn.go index 6aad628b..a037aff9 100644 --- a/conn.go +++ b/conn.go @@ -54,6 +54,7 @@ type Conn struct { conn net.Conn isTLS bool isClosing bool + closeErr error isStartingTLS bool Debug debugging chanConfirm chan bool @@ -298,6 +299,11 @@ func (l *Conn) processMessages() { log.Printf("ldap: recovered panic in processMessages: %v", err) } for messageID, channel := range l.chanResults { + // If we are closing due to an error, inform anyone who + // is waiting about the error. + if l.isClosing && l.closeErr != nil { + channel <- &PacketResponse{Error: l.closeErr} + } l.Debug.Printf("Closing channel for MessageID %d", messageID) close(channel) delete(l.chanResults, messageID) @@ -324,15 +330,20 @@ func (l *Conn) processMessages() { case MessageRequest: // Add to message list and write to network l.Debug.Printf("Sending message %d", message.MessageID) - l.chanResults[message.MessageID] = message.Channel buf := message.Packet.Bytes() _, err := l.conn.Write(buf) if err != nil { l.Debug.Printf("Error Sending Message: %s", err.Error()) + message.Channel <- &PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)} + close(message.Channel) break } + // Only add to chanResults if we were able to + // successfully write the message. + l.chanResults[message.MessageID] = message.Channel + // Add timeout if defined if l.requestTimeout > 0 { go func() { @@ -397,6 +408,7 @@ func (l *Conn) reader() { if err != nil { // A read error is expected here if we are closing the connection... if !l.isClosing { + l.closeErr = fmt.Errorf("unable to read LDAP response packet: %s", err) l.Debug.Printf("reader error: %s", err.Error()) } return diff --git a/error_test.go b/error_test.go index 4ec720d9..c010ebe3 100644 --- a/error_test.go +++ b/error_test.go @@ -1,7 +1,11 @@ package ldap import ( + "errors" + "net" + "strings" "testing" + "time" "gopkg.in/asn1-ber.v1" ) @@ -16,8 +20,8 @@ func TestNilPacket(t *testing.T) { // Test for nil result kids := []*ber.Packet{ - &ber.Packet{}, // Unused - nil, // Can't be nil + {}, // Unused + nil, // Can't be nil } pack := &ber.Packet{Children: kids} code, _ = getLDAPResultCode(pack) @@ -25,5 +29,74 @@ func TestNilPacket(t *testing.T) { if code != ErrorUnexpectedResponse { t.Errorf("Should have an 'ErrorUnexpectedResponse' error in nil packets, got: %v", code) } +} + +// TestConnReadErr tests that an unexpected error reading from underlying +// connection bubbles up to the goroutine which makes a request. +func TestConnReadErr(t *testing.T) { + conn := &signalErrConn{ + signals: make(chan error), + } + + ldapConn := NewConn(conn, false) + ldapConn.Start() + + // Make a dummy search request. + searchReq := NewSearchRequest("dc=example,dc=com", ScopeWholeSubtree, DerefAlways, 0, 0, false, "(objectClass=*)", nil, nil) + + expectedError := errors.New("this is the error you are looking for") + + // Send the signal after a short amount of time. + time.AfterFunc(10*time.Millisecond, func() { conn.signals <- expectedError }) + + // This should block until the underlyiny conn gets the error signal + // which should bubble up through the reader() goroutine, close the + // connection, and + _, err := ldapConn.Search(searchReq) + if err == nil || !strings.Contains(err.Error(), expectedError.Error()) { + t.Errorf("not the expected error: %s", err) + } +} + +// signalErrConn is a helful type used with TestConnReadErr. It implements the +// net.Conn interface to be used as a connection for the test. Most methods are +// no-ops but the Read() method blocks until it receives a signal which it +// returns as an error. +type signalErrConn struct { + signals chan error +} + +// Read blocks until an error is sent on the internal signals channel. That +// error is returned. +func (c *signalErrConn) Read(b []byte) (n int, err error) { + return 0, <-c.signals +} + +func (c *signalErrConn) Write(b []byte) (n int, err error) { + return len(b), nil +} + +func (c *signalErrConn) Close() error { + close(c.signals) + return nil +} + +func (c *signalErrConn) LocalAddr() net.Addr { + return (*net.TCPAddr)(nil) +} + +func (c *signalErrConn) RemoteAddr() net.Addr { + return (*net.TCPAddr)(nil) +} + +func (c *signalErrConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *signalErrConn) SetReadDeadline(t time.Time) error { + return nil +} +func (c *signalErrConn) SetWriteDeadline(t time.Time) error { + return nil }