diff --git a/bind.go b/bind.go index 26b3cc72..432efa78 100644 --- a/bind.go +++ b/bind.go @@ -7,7 +7,7 @@ package ldap import ( "errors" - "gopkg.in/asn1-ber.v1" + ber "gopkg.in/asn1-ber.v1" ) // SimpleBindRequest represents a username/password bind operation @@ -18,6 +18,9 @@ type SimpleBindRequest struct { Password string // Controls are optional controls to send with the bind request Controls []Control + // AllowEmptyPassword sets whether the client allows binding with an empty password + // (normally used for unauthenticated bind). + AllowEmptyPassword bool } // SimpleBindResult contains the response from the server @@ -28,9 +31,10 @@ type SimpleBindResult struct { // NewSimpleBindRequest returns a bind request func NewSimpleBindRequest(username string, password string, controls []Control) *SimpleBindRequest { return &SimpleBindRequest{ - Username: username, - Password: password, - Controls: controls, + Username: username, + Password: password, + Controls: controls, + AllowEmptyPassword: false, } } @@ -47,6 +51,10 @@ func (bindRequest *SimpleBindRequest) encode() *ber.Packet { // SimpleBind performs the simple bind operation defined in the given request func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) { + if simpleBindRequest.Password == "" && !simpleBindRequest.AllowEmptyPassword { + return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client")) + } + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) encodedBindRequest := simpleBindRequest.encode() @@ -97,47 +105,33 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu return result, nil } -// Bind performs a bind with the given username and password +// Bind performs a bind with the given username and password. +// +// It does not allow unauthenticated bind (i.e. empty password). Use the UnauthenticatedBind method +// for that. func (l *Conn) Bind(username, password string) error { - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") - bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) - bindRequest.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, username, "User Name")) - bindRequest.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, password, "Password")) - packet.AppendChild(bindRequest) - - if l.Debug { - ber.PrintPacket(packet) - } - - msgCtx, err := l.sendMessage(packet) - if err != nil { - return err - } - defer l.finishMessage(msgCtx) - - packetResponse, ok := <-msgCtx.responses - if !ok { - return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) - if err != nil { - return err - } - - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - return err - } - ber.PrintPacket(packet) + req := &SimpleBindRequest{ + Username: username, + Password: password, + AllowEmptyPassword: false, } + _, err := l.SimpleBind(req) + return err +} - resultCode, resultDescription := getLDAPResultCode(packet) - if resultCode != 0 { - return NewError(resultCode, errors.New(resultDescription)) +// UnauthenticatedBind performs an unauthenticated bind. +// +// A username may be provided for trace (e.g. logging) purpose only, but it is normally not +// authenticated or otherwise validated by the LDAP server. +// +// See https://tools.ietf.org/html/rfc4513#section-5.1.2 . +// See https://tools.ietf.org/html/rfc4513#section-6.3.1 . +func (l *Conn) UnauthenticatedBind(username string) error { + req := &SimpleBindRequest{ + Username: username, + Password: "", + AllowEmptyPassword: true, } - - return nil + _, err := l.SimpleBind(req) + return err } diff --git a/error.go b/error.go index 4cccb537..6e1277fd 100644 --- a/error.go +++ b/error.go @@ -54,6 +54,7 @@ const ( ErrorDebugging = 203 ErrorUnexpectedMessage = 204 ErrorUnexpectedResponse = 205 + ErrorEmptyPassword = 206 ) // LDAPResultCodeMap contains string descriptions for LDAP error codes @@ -104,6 +105,7 @@ var LDAPResultCodeMap = map[uint8]string{ ErrorDebugging: "Debugging Error", ErrorUnexpectedMessage: "Unexpected Message", ErrorUnexpectedResponse: "Unexpected Response", + ErrorEmptyPassword: "Empty password not allowed by the client", } func getLDAPResultCode(packet *ber.Packet) (code uint8, description string) { diff --git a/ldap_test.go b/ldap_test.go index 9f430518..58f8260e 100644 --- a/ldap_test.go +++ b/ldap_test.go @@ -1,11 +1,9 @@ -package ldap_test +package ldap import ( "crypto/tls" "fmt" "testing" - - "gopkg.in/ldap.v2" ) var ldapServer = "ldap.itd.umich.edu" @@ -23,7 +21,7 @@ var attributes = []string{ func TestDial(t *testing.T) { fmt.Printf("TestDial: starting...\n") - l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) if err != nil { t.Errorf(err.Error()) return @@ -34,7 +32,7 @@ func TestDial(t *testing.T) { func TestDialTLS(t *testing.T) { fmt.Printf("TestDialTLS: starting...\n") - l, err := ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapTLSPort), &tls.Config{InsecureSkipVerify: true}) + l, err := DialTLS("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapTLSPort), &tls.Config{InsecureSkipVerify: true}) if err != nil { t.Errorf(err.Error()) return @@ -45,7 +43,7 @@ func TestDialTLS(t *testing.T) { func TestStartTLS(t *testing.T) { fmt.Printf("TestStartTLS: starting...\n") - l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) if err != nil { t.Errorf(err.Error()) return @@ -60,16 +58,16 @@ func TestStartTLS(t *testing.T) { func TestSearch(t *testing.T) { fmt.Printf("TestSearch: starting...\n") - l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) if err != nil { t.Errorf(err.Error()) return } defer l.Close() - searchRequest := ldap.NewSearchRequest( + searchRequest := NewSearchRequest( baseDN, - ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false, + ScopeWholeSubtree, DerefAlways, 0, 0, false, filter[0], attributes, nil) @@ -85,16 +83,16 @@ func TestSearch(t *testing.T) { func TestSearchStartTLS(t *testing.T) { fmt.Printf("TestSearchStartTLS: starting...\n") - l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) if err != nil { t.Errorf(err.Error()) return } defer l.Close() - searchRequest := ldap.NewSearchRequest( + searchRequest := NewSearchRequest( baseDN, - ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false, + ScopeWholeSubtree, DerefAlways, 0, 0, false, filter[0], attributes, nil) @@ -125,22 +123,22 @@ func TestSearchStartTLS(t *testing.T) { func TestSearchWithPaging(t *testing.T) { fmt.Printf("TestSearchWithPaging: starting...\n") - l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) if err != nil { t.Errorf(err.Error()) return } defer l.Close() - err = l.Bind("", "") + err = l.UnauthenticatedBind("") if err != nil { t.Errorf(err.Error()) return } - searchRequest := ldap.NewSearchRequest( + searchRequest := NewSearchRequest( baseDN, - ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false, + ScopeWholeSubtree, DerefAlways, 0, 0, false, filter[2], attributes, nil) @@ -152,12 +150,12 @@ func TestSearchWithPaging(t *testing.T) { fmt.Printf("TestSearchWithPaging: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries)) - searchRequest = ldap.NewSearchRequest( + searchRequest = NewSearchRequest( baseDN, - ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false, + ScopeWholeSubtree, DerefAlways, 0, 0, false, filter[2], attributes, - []ldap.Control{ldap.NewControlPaging(5)}) + []Control{NewControlPaging(5)}) sr, err = l.SearchWithPaging(searchRequest, 5) if err != nil { t.Errorf(err.Error()) @@ -166,12 +164,12 @@ func TestSearchWithPaging(t *testing.T) { fmt.Printf("TestSearchWithPaging: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries)) - searchRequest = ldap.NewSearchRequest( + searchRequest = NewSearchRequest( baseDN, - ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false, + ScopeWholeSubtree, DerefAlways, 0, 0, false, filter[2], attributes, - []ldap.Control{ldap.NewControlPaging(500)}) + []Control{NewControlPaging(500)}) sr, err = l.SearchWithPaging(searchRequest, 5) if err == nil { t.Errorf("expected an error when paging size in control in search request doesn't match size given in call, got none") @@ -179,10 +177,10 @@ func TestSearchWithPaging(t *testing.T) { } } -func searchGoroutine(t *testing.T, l *ldap.Conn, results chan *ldap.SearchResult, i int) { - searchRequest := ldap.NewSearchRequest( +func searchGoroutine(t *testing.T, l *Conn, results chan *SearchResult, i int) { + searchRequest := NewSearchRequest( baseDN, - ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false, + ScopeWholeSubtree, DerefAlways, 0, 0, false, filter[i], attributes, nil) @@ -197,17 +195,17 @@ func searchGoroutine(t *testing.T, l *ldap.Conn, results chan *ldap.SearchResult func testMultiGoroutineSearch(t *testing.T, TLS bool, startTLS bool) { fmt.Printf("TestMultiGoroutineSearch: starting...\n") - var l *ldap.Conn + var l *Conn var err error if TLS { - l, err = ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapTLSPort), &tls.Config{InsecureSkipVerify: true}) + l, err = DialTLS("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapTLSPort), &tls.Config{InsecureSkipVerify: true}) if err != nil { t.Errorf(err.Error()) return } defer l.Close() } else { - l, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + l, err = Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) if err != nil { t.Errorf(err.Error()) return @@ -223,9 +221,9 @@ func testMultiGoroutineSearch(t *testing.T, TLS bool, startTLS bool) { } } - results := make([]chan *ldap.SearchResult, len(filter)) + results := make([]chan *SearchResult, len(filter)) for i := range filter { - results[i] = make(chan *ldap.SearchResult) + results[i] = make(chan *SearchResult) go searchGoroutine(t, l, results[i], i) } for i := range filter { @@ -245,17 +243,17 @@ func TestMultiGoroutineSearch(t *testing.T) { } func TestEscapeFilter(t *testing.T) { - if got, want := ldap.EscapeFilter("a\x00b(c)d*e\\f"), `a\00b\28c\29d\2ae\5cf`; got != want { + if got, want := EscapeFilter("a\x00b(c)d*e\\f"), `a\00b\28c\29d\2ae\5cf`; got != want { t.Errorf("Got %s, expected %s", want, got) } - if got, want := ldap.EscapeFilter("Lučić"), `Lu\c4\8di\c4\87`; got != want { + if got, want := EscapeFilter("Lučić"), `Lu\c4\8di\c4\87`; got != want { t.Errorf("Got %s, expected %s", want, got) } } func TestCompare(t *testing.T) { fmt.Printf("TestCompare: starting...\n") - l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) if err != nil { t.Fatal(err.Error()) }