Skip to content

Commit

Permalink
Merge pull request #126 from tiziano88/check_empty_password
Browse files Browse the repository at this point in the history
Require explicit intention for empty password.
  • Loading branch information
johnweldon authored Aug 24, 2017
2 parents 37f35d7 + bb09d4b commit 95ede12
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 76 deletions.
80 changes: 37 additions & 43 deletions bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package ldap
import (
"errors"

"gopkg.in/asn1-ber.v1"
ber "gopkg.in/asn1-ber.v1"
)

// SimpleBindRequest represents a username/password bind operation
Expand All @@ -18,6 +18,9 @@ type SimpleBindRequest struct {
Password string
// Controls are optional controls to send with the bind request
Controls []Control
// AllowEmptyPassword sets whether the client allows binding with an empty password
// (normally used for unauthenticated bind).
AllowEmptyPassword bool
}

// SimpleBindResult contains the response from the server
Expand All @@ -28,9 +31,10 @@ type SimpleBindResult struct {
// NewSimpleBindRequest returns a bind request
func NewSimpleBindRequest(username string, password string, controls []Control) *SimpleBindRequest {
return &SimpleBindRequest{
Username: username,
Password: password,
Controls: controls,
Username: username,
Password: password,
Controls: controls,
AllowEmptyPassword: false,
}
}

Expand All @@ -47,6 +51,10 @@ func (bindRequest *SimpleBindRequest) encode() *ber.Packet {

// SimpleBind performs the simple bind operation defined in the given request
func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) {
if simpleBindRequest.Password == "" && !simpleBindRequest.AllowEmptyPassword {
return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client"))
}

packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
encodedBindRequest := simpleBindRequest.encode()
Expand Down Expand Up @@ -97,47 +105,33 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu
return result, nil
}

// Bind performs a bind with the given username and password
// Bind performs a bind with the given username and password.
//
// It does not allow unauthenticated bind (i.e. empty password). Use the UnauthenticatedBind method
// for that.
func (l *Conn) Bind(username, password string) error {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
bindRequest.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, username, "User Name"))
bindRequest.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, password, "Password"))
packet.AppendChild(bindRequest)

if l.Debug {
ber.PrintPacket(packet)
}

msgCtx, err := l.sendMessage(packet)
if err != nil {
return err
}
defer l.finishMessage(msgCtx)

packetResponse, ok := <-msgCtx.responses
if !ok {
return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
}
packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil {
return err
}

if l.Debug {
if err := addLDAPDescriptions(packet); err != nil {
return err
}
ber.PrintPacket(packet)
req := &SimpleBindRequest{
Username: username,
Password: password,
AllowEmptyPassword: false,
}
_, err := l.SimpleBind(req)
return err
}

resultCode, resultDescription := getLDAPResultCode(packet)
if resultCode != 0 {
return NewError(resultCode, errors.New(resultDescription))
// UnauthenticatedBind performs an unauthenticated bind.
//
// A username may be provided for trace (e.g. logging) purpose only, but it is normally not
// authenticated or otherwise validated by the LDAP server.
//
// See https://tools.ietf.org/html/rfc4513#section-5.1.2 .
// See https://tools.ietf.org/html/rfc4513#section-6.3.1 .
func (l *Conn) UnauthenticatedBind(username string) error {
req := &SimpleBindRequest{
Username: username,
Password: "",
AllowEmptyPassword: true,
}

return nil
_, err := l.SimpleBind(req)
return err
}
2 changes: 2 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ const (
ErrorDebugging = 203
ErrorUnexpectedMessage = 204
ErrorUnexpectedResponse = 205
ErrorEmptyPassword = 206
)

// LDAPResultCodeMap contains string descriptions for LDAP error codes
Expand Down Expand Up @@ -104,6 +105,7 @@ var LDAPResultCodeMap = map[uint8]string{
ErrorDebugging: "Debugging Error",
ErrorUnexpectedMessage: "Unexpected Message",
ErrorUnexpectedResponse: "Unexpected Response",
ErrorEmptyPassword: "Empty password not allowed by the client",
}

func getLDAPResultCode(packet *ber.Packet) (code uint8, description string) {
Expand Down
64 changes: 31 additions & 33 deletions ldap_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package ldap_test
package ldap

import (
"crypto/tls"
"fmt"
"testing"

"gopkg.in/ldap.v2"
)

var ldapServer = "ldap.itd.umich.edu"
Expand All @@ -23,7 +21,7 @@ var attributes = []string{

func TestDial(t *testing.T) {
fmt.Printf("TestDial: starting...\n")
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
Expand All @@ -34,7 +32,7 @@ func TestDial(t *testing.T) {

func TestDialTLS(t *testing.T) {
fmt.Printf("TestDialTLS: starting...\n")
l, err := ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapTLSPort), &tls.Config{InsecureSkipVerify: true})
l, err := DialTLS("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapTLSPort), &tls.Config{InsecureSkipVerify: true})
if err != nil {
t.Errorf(err.Error())
return
Expand All @@ -45,7 +43,7 @@ func TestDialTLS(t *testing.T) {

func TestStartTLS(t *testing.T) {
fmt.Printf("TestStartTLS: starting...\n")
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
Expand All @@ -60,16 +58,16 @@ func TestStartTLS(t *testing.T) {

func TestSearch(t *testing.T) {
fmt.Printf("TestSearch: starting...\n")
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
}
defer l.Close()

searchRequest := ldap.NewSearchRequest(
searchRequest := NewSearchRequest(
baseDN,
ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[0],
attributes,
nil)
Expand All @@ -85,16 +83,16 @@ func TestSearch(t *testing.T) {

func TestSearchStartTLS(t *testing.T) {
fmt.Printf("TestSearchStartTLS: starting...\n")
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
}
defer l.Close()

searchRequest := ldap.NewSearchRequest(
searchRequest := NewSearchRequest(
baseDN,
ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[0],
attributes,
nil)
Expand Down Expand Up @@ -125,22 +123,22 @@ func TestSearchStartTLS(t *testing.T) {

func TestSearchWithPaging(t *testing.T) {
fmt.Printf("TestSearchWithPaging: starting...\n")
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
}
defer l.Close()

err = l.Bind("", "")
err = l.UnauthenticatedBind("")
if err != nil {
t.Errorf(err.Error())
return
}

searchRequest := ldap.NewSearchRequest(
searchRequest := NewSearchRequest(
baseDN,
ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[2],
attributes,
nil)
Expand All @@ -152,12 +150,12 @@ func TestSearchWithPaging(t *testing.T) {

fmt.Printf("TestSearchWithPaging: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries))

searchRequest = ldap.NewSearchRequest(
searchRequest = NewSearchRequest(
baseDN,
ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[2],
attributes,
[]ldap.Control{ldap.NewControlPaging(5)})
[]Control{NewControlPaging(5)})
sr, err = l.SearchWithPaging(searchRequest, 5)
if err != nil {
t.Errorf(err.Error())
Expand All @@ -166,23 +164,23 @@ func TestSearchWithPaging(t *testing.T) {

fmt.Printf("TestSearchWithPaging: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries))

searchRequest = ldap.NewSearchRequest(
searchRequest = NewSearchRequest(
baseDN,
ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[2],
attributes,
[]ldap.Control{ldap.NewControlPaging(500)})
[]Control{NewControlPaging(500)})
sr, err = l.SearchWithPaging(searchRequest, 5)
if err == nil {
t.Errorf("expected an error when paging size in control in search request doesn't match size given in call, got none")
return
}
}

func searchGoroutine(t *testing.T, l *ldap.Conn, results chan *ldap.SearchResult, i int) {
searchRequest := ldap.NewSearchRequest(
func searchGoroutine(t *testing.T, l *Conn, results chan *SearchResult, i int) {
searchRequest := NewSearchRequest(
baseDN,
ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[i],
attributes,
nil)
Expand All @@ -197,17 +195,17 @@ func searchGoroutine(t *testing.T, l *ldap.Conn, results chan *ldap.SearchResult

func testMultiGoroutineSearch(t *testing.T, TLS bool, startTLS bool) {
fmt.Printf("TestMultiGoroutineSearch: starting...\n")
var l *ldap.Conn
var l *Conn
var err error
if TLS {
l, err = ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapTLSPort), &tls.Config{InsecureSkipVerify: true})
l, err = DialTLS("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapTLSPort), &tls.Config{InsecureSkipVerify: true})
if err != nil {
t.Errorf(err.Error())
return
}
defer l.Close()
} else {
l, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
l, err = Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
Expand All @@ -223,9 +221,9 @@ func testMultiGoroutineSearch(t *testing.T, TLS bool, startTLS bool) {
}
}

results := make([]chan *ldap.SearchResult, len(filter))
results := make([]chan *SearchResult, len(filter))
for i := range filter {
results[i] = make(chan *ldap.SearchResult)
results[i] = make(chan *SearchResult)
go searchGoroutine(t, l, results[i], i)
}
for i := range filter {
Expand All @@ -245,17 +243,17 @@ func TestMultiGoroutineSearch(t *testing.T) {
}

func TestEscapeFilter(t *testing.T) {
if got, want := ldap.EscapeFilter("a\x00b(c)d*e\\f"), `a\00b\28c\29d\2ae\5cf`; got != want {
if got, want := EscapeFilter("a\x00b(c)d*e\\f"), `a\00b\28c\29d\2ae\5cf`; got != want {
t.Errorf("Got %s, expected %s", want, got)
}
if got, want := ldap.EscapeFilter("Lučić"), `Lu\c4\8di\c4\87`; got != want {
if got, want := EscapeFilter("Lučić"), `Lu\c4\8di\c4\87`; got != want {
t.Errorf("Got %s, expected %s", want, got)
}
}

func TestCompare(t *testing.T) {
fmt.Printf("TestCompare: starting...\n")
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Fatal(err.Error())
}
Expand Down

0 comments on commit 95ede12

Please sign in to comment.