Skip to content

Commit

Permalink
Move LDAP client and config code to helper (#4532)
Browse files Browse the repository at this point in the history
  • Loading branch information
Becca Petrin authored May 10, 2018
1 parent 03a48f2 commit 8ea9efd
Show file tree
Hide file tree
Showing 10 changed files with 616 additions and 407 deletions.
260 changes: 10 additions & 250 deletions builtin/credential/ldap/backend.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
package ldap

import (
"bytes"
"context"
"fmt"
"math"
"strings"
"text/template"

"github.com/go-ldap/ldap"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/ldaputil"
"github.com/hashicorp/vault/helper/mfa"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
Expand Down Expand Up @@ -62,40 +58,6 @@ type backend struct {
*framework.Backend
}

func EscapeLDAPValue(input string) string {
// RFC4514 forbids un-escaped:
// - leading space or hash
// - trailing space
// - special characters '"', '+', ',', ';', '<', '>', '\\'
// - null
for i := 0; i < len(input); i++ {
escaped := false
if input[i] == '\\' {
i++
escaped = true
}
switch input[i] {
case '"', '+', ',', ';', '<', '>', '\\':
if !escaped {
input = input[0:i] + "\\" + input[i:]
i++
}
continue
}
if escaped {
input = input[0:i] + "\\" + input[i:]
i++
}
}
if input[0] == ' ' || input[0] == '#' {
input = "\\" + input
}
if input[len(input)-1] == ' ' {
input = input[0:len(input)-1] + "\\ "
}
return input
}

func (b *backend) Login(ctx context.Context, req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {

cfg, err := b.Config(ctx, req)
Expand All @@ -106,7 +68,12 @@ func (b *backend) Login(ctx context.Context, req *logical.Request, username stri
return nil, logical.ErrorResponse("ldap backend not configured"), nil, nil
}

c, err := cfg.DialLDAP()
ldapClient := ldaputil.Client{
Logger: b.Logger(),
LDAP: ldaputil.NewLDAP(),
}

c, err := ldapClient.DialLDAP(cfg)
if err != nil {
return nil, logical.ErrorResponse(err.Error()), nil, nil
}
Expand All @@ -117,7 +84,7 @@ func (b *backend) Login(ctx context.Context, req *logical.Request, username stri
// Clean connection
defer c.Close()

userBindDN, err := b.getUserBindDN(cfg, c, username)
userBindDN, err := ldapClient.GetUserBindDN(cfg, c, username)
if err != nil {
return nil, logical.ErrorResponse(err.Error()), nil, nil
}
Expand Down Expand Up @@ -151,12 +118,12 @@ func (b *backend) Login(ctx context.Context, req *logical.Request, username stri
}
}

userDN, err := b.getUserDN(cfg, c, userBindDN)
userDN, err := ldapClient.GetUserDN(cfg, c, userBindDN)
if err != nil {
return nil, logical.ErrorResponse(err.Error()), nil, nil
}

ldapGroups, err := b.getLdapGroups(cfg, c, userDN, username)
ldapGroups, err := ldapClient.GetLdapGroups(cfg, c, userDN, username)
if err != nil {
return nil, logical.ErrorResponse(err.Error()), nil, nil
}
Expand Down Expand Up @@ -227,213 +194,6 @@ func (b *backend) Login(ctx context.Context, req *logical.Request, username stri
return policies, ldapResponse, allGroups, nil
}

/*
* Parses a distinguished name and returns the CN portion.
* Given a non-conforming string (such as an already-extracted CN),
* it will be returned as-is.
*/
func (b *backend) getCN(dn string) string {
parsedDN, err := ldap.ParseDN(dn)
if err != nil || len(parsedDN.RDNs) == 0 {
// It was already a CN, return as-is
return dn
}

for _, rdn := range parsedDN.RDNs {
for _, rdnAttr := range rdn.Attributes {
if rdnAttr.Type == "CN" {
return rdnAttr.Value
}
}
}

// Default, return self
return dn
}

/*
* Discover and return the bind string for the user attempting to authenticate.
* This is handled in one of several ways:
*
* 1. If DiscoverDN is set, the user object will be searched for using userdn (base search path)
* and userattr (the attribute that maps to the provided username).
* The bind will either be anonymous or use binddn and bindpassword if they were provided.
* 2. If upndomain is set, the user dn is constructed as 'username@upndomain'. See https://msdn.microsoft.com/en-us/library/cc223499.aspx
*
*/
func (b *backend) getUserBindDN(cfg *ConfigEntry, c *ldap.Conn, username string) (string, error) {
bindDN := ""
if cfg.DiscoverDN || (cfg.BindDN != "" && cfg.BindPassword != "") {
var err error
if cfg.BindPassword != "" {
err = c.Bind(cfg.BindDN, cfg.BindPassword)
} else {
err = c.UnauthenticatedBind(cfg.BindDN)
}
if err != nil {
return bindDN, errwrap.Wrapf("LDAP bind (service) failed: {{err}}", err)
}

filter := fmt.Sprintf("(%s=%s)", cfg.UserAttr, ldap.EscapeFilter(username))
if b.Logger().IsDebug() {
b.Logger().Debug("discovering user", "userdn", cfg.UserDN, "filter", filter)
}
result, err := c.Search(&ldap.SearchRequest{
BaseDN: cfg.UserDN,
Scope: 2, // subtree
Filter: filter,
SizeLimit: math.MaxInt32,
})
if err != nil {
return bindDN, errwrap.Wrapf("LDAP search for binddn failed: {{err}}", err)
}
if len(result.Entries) != 1 {
return bindDN, fmt.Errorf("LDAP search for binddn 0 or not unique")
}
bindDN = result.Entries[0].DN
} else {
if cfg.UPNDomain != "" {
bindDN = fmt.Sprintf("%s@%s", EscapeLDAPValue(username), cfg.UPNDomain)
} else {
bindDN = fmt.Sprintf("%s=%s,%s", cfg.UserAttr, EscapeLDAPValue(username), cfg.UserDN)
}
}

return bindDN, nil
}

/*
* Returns the DN of the object representing the authenticated user.
*/
func (b *backend) getUserDN(cfg *ConfigEntry, c *ldap.Conn, bindDN string) (string, error) {
userDN := ""
if cfg.UPNDomain != "" {
// Find the distinguished name for the user if userPrincipalName used for login
filter := fmt.Sprintf("(userPrincipalName=%s)", ldap.EscapeFilter(bindDN))
if b.Logger().IsDebug() {
b.Logger().Debug("searching upn", "userdn", cfg.UserDN, "filter", filter)
}
result, err := c.Search(&ldap.SearchRequest{
BaseDN: cfg.UserDN,
Scope: 2, // subtree
Filter: filter,
SizeLimit: math.MaxInt32,
})
if err != nil {
return userDN, errwrap.Wrapf("LDAP search failed for detecting user: {{err}}", err)
}
for _, e := range result.Entries {
userDN = e.DN
}
} else {
userDN = bindDN
}

return userDN, nil
}

/*
* getLdapGroups queries LDAP and returns a slice describing the set of groups the authenticated user is a member of.
*
* The search query is constructed according to cfg.GroupFilter, and run in context of cfg.GroupDN.
* Groups will be resolved from the query results by following the attribute defined in cfg.GroupAttr.
*
* cfg.GroupFilter is a go template and is compiled with the following context: [UserDN, Username]
* UserDN - The DN of the authenticated user
* Username - The Username of the authenticated user
*
* Example:
* cfg.GroupFilter = "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))"
* cfg.GroupDN = "OU=Groups,DC=myorg,DC=com"
* cfg.GroupAttr = "cn"
*
* NOTE - If cfg.GroupFilter is empty, no query is performed and an empty result slice is returned.
*
*/
func (b *backend) getLdapGroups(cfg *ConfigEntry, c *ldap.Conn, userDN string, username string) ([]string, error) {
// retrieve the groups in a string/bool map as a structure to avoid duplicates inside
ldapMap := make(map[string]bool)

if cfg.GroupFilter == "" {
b.Logger().Warn("groupfilter is empty, will not query server")
return make([]string, 0), nil
}

if cfg.GroupDN == "" {
b.Logger().Warn("groupdn is empty, will not query server")
return make([]string, 0), nil
}

// If groupfilter was defined, resolve it as a Go template and use the query for
// returning the user's groups
if b.Logger().IsDebug() {
b.Logger().Debug("compiling group filter", "group_filter", cfg.GroupFilter)
}

// Parse the configuration as a template.
// Example template "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))"
t, err := template.New("queryTemplate").Parse(cfg.GroupFilter)
if err != nil {
return nil, errwrap.Wrapf("LDAP search failed due to template compilation error: {{err}}", err)
}

// Build context to pass to template - we will be exposing UserDn and Username.
context := struct {
UserDN string
Username string
}{
ldap.EscapeFilter(userDN),
ldap.EscapeFilter(username),
}

var renderedQuery bytes.Buffer
t.Execute(&renderedQuery, context)

if b.Logger().IsDebug() {
b.Logger().Debug("searching", "groupdn", cfg.GroupDN, "rendered_query", renderedQuery.String())
}

result, err := c.Search(&ldap.SearchRequest{
BaseDN: cfg.GroupDN,
Scope: 2, // subtree
Filter: renderedQuery.String(),
Attributes: []string{
cfg.GroupAttr,
},
SizeLimit: math.MaxInt32,
})
if err != nil {
return nil, errwrap.Wrapf("LDAP search failed: {{err}}", err)
}

for _, e := range result.Entries {
dn, err := ldap.ParseDN(e.DN)
if err != nil || len(dn.RDNs) == 0 {
continue
}

// Enumerate attributes of each result, parse out CN and add as group
values := e.GetAttributeValues(cfg.GroupAttr)
if len(values) > 0 {
for _, val := range values {
groupCN := b.getCN(val)
ldapMap[groupCN] = true
}
} else {
// If groupattr didn't resolve, use self (enumerating group objects)
groupCN := b.getCN(e.DN)
ldapMap[groupCN] = true
}
}

ldapGroups := make([]string, 0, len(ldapMap))
for key, _ := range ldapMap {
ldapGroups = append(ldapGroups, key)
}

return ldapGroups, nil
}

const backendHelp = `
The "ldap" credential provider allows authentication querying
a LDAP server, checking username and password, and associating groups
Expand Down
17 changes: 0 additions & 17 deletions builtin/credential/ldap/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,23 +654,6 @@ func testAccStepLoginNoGroupDN(t *testing.T, user string, pass string) logicalte
}
}

func TestLDAPEscape(t *testing.T) {
testcases := map[string]string{
"#test": "\\#test",
"test,hello": "test\\,hello",
"test,hel+lo": "test\\,hel\\+lo",
"test\\hello": "test\\\\hello",
" test ": "\\ test \\ ",
}

for test, answer := range testcases {
res := EscapeLDAPValue(test)
if res != answer {
t.Errorf("Failed to escape %s: %s != %s\n", test, res, answer)
}
}
}

func testAccStepGroupList(t *testing.T, groups []string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ListOperation,
Expand Down
Loading

0 comments on commit 8ea9efd

Please sign in to comment.