Skip to content

Commit

Permalink
fix: memory allocations (#252)
Browse files Browse the repository at this point in the history
* fix: memory allocations

* fix: replace strings.Builder to bytes.Buffer for old goland versions

* fix: rename escapedStringToEncodedBytes to decodeEscapedSymbols

* feat: remove one allocation

* fix (v3): memory allocations
  • Loading branch information
khevse authored and johnweldon committed Dec 14, 2019
1 parent a75d3c9 commit a4f79d8
Show file tree
Hide file tree
Showing 4 changed files with 276 additions and 156 deletions.
176 changes: 99 additions & 77 deletions filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ import (
hexpac "encoding/hex"
"errors"
"fmt"
"io"
"strings"
"unicode"
"unicode/utf8"

"github.com/go-asn1-ber/asn1-ber"
ber "github.com/go-asn1-ber/asn1-ber"
)

// Filter choices
Expand Down Expand Up @@ -69,6 +71,8 @@ var MatchingRuleAssertionMap = map[uint64]string{
MatchingRuleAssertionDNAttributes: "Matching Rule Assertion DN Attributes",
}

var _SymbolAny = []byte{'*'}

// CompileFilter converts a string representation of a filter into a BER-encoded packet
func CompileFilter(filter string) (*ber.Packet, error) {
if len(filter) == 0 || filter[0] != '(' {
Expand All @@ -88,74 +92,75 @@ func CompileFilter(filter string) (*ber.Packet, error) {
}

// DecompileFilter converts a packet representation of a filter into a string representation
func DecompileFilter(packet *ber.Packet) (ret string, err error) {
func DecompileFilter(packet *ber.Packet) (_ string, err error) {
defer func() {
if r := recover(); r != nil {
err = NewError(ErrorFilterDecompile, errors.New("ldap: error decompiling filter"))
}
}()
ret = "("
err = nil

buf := bytes.NewBuffer(nil)
buf.WriteByte('(')
childStr := ""

switch packet.Tag {
case FilterAnd:
ret += "&"
buf.WriteByte('&')
for _, child := range packet.Children {
childStr, err = DecompileFilter(child)
if err != nil {
return
}
ret += childStr
buf.WriteString(childStr)
}
case FilterOr:
ret += "|"
buf.WriteByte('|')
for _, child := range packet.Children {
childStr, err = DecompileFilter(child)
if err != nil {
return
}
ret += childStr
buf.WriteString(childStr)
}
case FilterNot:
ret += "!"
buf.WriteByte('!')
childStr, err = DecompileFilter(packet.Children[0])
if err != nil {
return
}
ret += childStr
buf.WriteString(childStr)

case FilterSubstrings:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
ret += "="
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
buf.WriteByte('=')
for i, child := range packet.Children[1].Children {
if i == 0 && child.Tag != FilterSubstringsInitial {
ret += "*"
buf.Write(_SymbolAny)
}
ret += EscapeFilter(ber.DecodeString(child.Data.Bytes()))
buf.WriteString(EscapeFilter(ber.DecodeString(child.Data.Bytes())))
if child.Tag != FilterSubstringsFinal {
ret += "*"
buf.Write(_SymbolAny)
}
}
case FilterEqualityMatch:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
ret += "="
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
buf.WriteByte('=')
buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())))
case FilterGreaterOrEqual:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
ret += ">="
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
buf.WriteString(">=")
buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())))
case FilterLessOrEqual:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
ret += "<="
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
buf.WriteString("<=")
buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())))
case FilterPresent:
ret += ber.DecodeString(packet.Data.Bytes())
ret += "=*"
buf.WriteString(ber.DecodeString(packet.Data.Bytes()))
buf.WriteString("=*")
case FilterApproxMatch:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
ret += "~="
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
buf.WriteString("~=")
buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())))
case FilterExtensibleMatch:
attr := ""
dnAttributes := false
Expand All @@ -176,21 +181,22 @@ func DecompileFilter(packet *ber.Packet) (ret string, err error) {
}

if len(attr) > 0 {
ret += attr
buf.WriteString(attr)
}
if dnAttributes {
ret += ":dn"
buf.WriteString(":dn")
}
if len(matchingRule) > 0 {
ret += ":"
ret += matchingRule
buf.WriteString(":")
buf.WriteString(matchingRule)
}
ret += ":="
ret += EscapeFilter(value)
buf.WriteString(":=")
buf.WriteString(EscapeFilter(value))
}

ret += ")"
return
buf.WriteByte(')')

return buf.String(), nil
}

func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, error) {
Expand Down Expand Up @@ -253,11 +259,10 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
)

state := stateReadingAttr

attribute := ""
attribute := bytes.NewBuffer(nil)
extensibleDNAttributes := false
extensibleMatchingRule := ""
condition := ""
extensibleMatchingRule := bytes.NewBuffer(nil)
condition := bytes.NewBuffer(nil)

for newPos < len(filter) {
remainingFilter := filter[newPos:]
Expand Down Expand Up @@ -324,7 +329,7 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {

// Still reading the attribute name
default:
attribute += fmt.Sprintf("%c", currentRune)
attribute.WriteRune(currentRune)
newPos += currentWidth
}

Expand All @@ -338,13 +343,13 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {

// Still reading the matching rule oid
default:
extensibleMatchingRule += fmt.Sprintf("%c", currentRune)
extensibleMatchingRule.WriteRune(currentRune)
newPos += currentWidth
}

case stateReadingCondition:
// append to the condition
condition += fmt.Sprintf("%c", currentRune)
condition.WriteRune(currentRune)
newPos += currentWidth
}
}
Expand All @@ -368,17 +373,17 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
// }

// Include the matching rule oid, if specified
if len(extensibleMatchingRule) > 0 {
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchingRule, extensibleMatchingRule, MatchingRuleAssertionMap[MatchingRuleAssertionMatchingRule]))
if extensibleMatchingRule.Len() > 0 {
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchingRule, extensibleMatchingRule.String(), MatchingRuleAssertionMap[MatchingRuleAssertionMatchingRule]))
}

// Include the attribute, if specified
if len(attribute) > 0 {
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionType, attribute, MatchingRuleAssertionMap[MatchingRuleAssertionType]))
if attribute.Len() > 0 {
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionType, attribute.String(), MatchingRuleAssertionMap[MatchingRuleAssertionType]))
}

// Add the value (only required child)
encodedString, encodeErr := escapedStringToEncodedBytes(condition)
encodedString, encodeErr := decodeEscapedSymbols(condition.Bytes())
if encodeErr != nil {
return packet, newPos, encodeErr
}
Expand All @@ -389,16 +394,16 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
packet.AppendChild(ber.NewBoolean(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionDNAttributes, extensibleDNAttributes, MatchingRuleAssertionMap[MatchingRuleAssertionDNAttributes]))
}

case packet.Tag == FilterEqualityMatch && condition == "*":
packet = ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterPresent, attribute, FilterMap[FilterPresent])
case packet.Tag == FilterEqualityMatch && strings.Contains(condition, "*"):
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
case packet.Tag == FilterEqualityMatch && bytes.Equal(condition.Bytes(), _SymbolAny):
packet = ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterPresent, attribute.String(), FilterMap[FilterPresent])
case packet.Tag == FilterEqualityMatch && bytes.Index(condition.Bytes(), _SymbolAny) > -1:
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute.String(), "Attribute"))
packet.Tag = FilterSubstrings
packet.Description = FilterMap[uint64(packet.Tag)]
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
parts := strings.Split(condition, "*")
parts := bytes.Split(condition.Bytes(), _SymbolAny)
for i, part := range parts {
if part == "" {
if len(part) == 0 {
continue
}
var tag ber.Tag
Expand All @@ -410,19 +415,19 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
default:
tag = FilterSubstringsAny
}
encodedString, encodeErr := escapedStringToEncodedBytes(part)
encodedString, encodeErr := decodeEscapedSymbols(part)
if encodeErr != nil {
return packet, newPos, encodeErr
}
seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, tag, encodedString, FilterSubstringsMap[uint64(tag)]))
}
packet.AppendChild(seq)
default:
encodedString, encodeErr := escapedStringToEncodedBytes(condition)
encodedString, encodeErr := decodeEscapedSymbols(condition.Bytes())
if encodeErr != nil {
return packet, newPos, encodeErr
}
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute.String(), "Attribute"))
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, encodedString, "Condition"))
}

Expand All @@ -432,34 +437,51 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
}

// Convert from "ABC\xx\xx\xx" form to literal bytes for transport
func escapedStringToEncodedBytes(escapedString string) (string, error) {
var buffer bytes.Buffer
i := 0
for i < len(escapedString) {
currentRune, currentWidth := utf8.DecodeRuneInString(escapedString[i:])
if currentRune == utf8.RuneError {
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", i))
func decodeEscapedSymbols(src []byte) (string, error) {

var (
buffer bytes.Buffer
offset int
reader = bytes.NewReader(src)
byteHex []byte
byteVal []byte
)

for {
runeVal, runeSize, err := reader.ReadRune()
if err == io.EOF {
return buffer.String(), nil
} else if err != nil {
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: failed to read filter: %v", err))
} else if runeVal == unicode.ReplacementChar {
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", offset))
}

// Check for escaped hex characters and convert them to their literal value for transport.
if currentRune == '\\' {
if runeVal == '\\' {
// http://tools.ietf.org/search/rfc4515
// \ (%x5C) is not a valid character unless it is followed by two HEX characters due to not
// being a member of UTF1SUBSET.
if i+2 > len(escapedString) {
return "", NewError(ErrorFilterCompile, errors.New("ldap: missing characters for escape in filter"))
if byteHex == nil {
byteHex = make([]byte, 2)
byteVal = make([]byte, 1)
}

if _, err := io.ReadFull(reader, byteHex); err != nil {
if err == io.ErrUnexpectedEOF {
return "", NewError(ErrorFilterCompile, errors.New("ldap: missing characters for escape in filter"))
}
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: invalid characters for escape in filter: %v", err))
}
escByte, decodeErr := hexpac.DecodeString(escapedString[i+1 : i+3])
if decodeErr != nil {
return "", NewError(ErrorFilterCompile, errors.New("ldap: invalid characters for escape in filter"))

if _, err := hexpac.Decode(byteVal, byteHex); err != nil {
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: invalid characters for escape in filter: %v", err))
}
buffer.WriteByte(escByte[0])
i += 2 // +1 from end of loop, so 3 total for \xx.

buffer.Write(byteVal)
} else {
buffer.WriteRune(currentRune)
buffer.WriteRune(runeVal)
}

i += currentWidth
offset += runeSize
}
return buffer.String(), nil
}
40 changes: 39 additions & 1 deletion filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"strings"
"testing"

"github.com/go-asn1-ber/asn1-ber"
ber "github.com/go-asn1-ber/asn1-ber"
)

type compileTest struct {
Expand Down Expand Up @@ -213,6 +213,44 @@ func TestFilter(t *testing.T) {
}
}

func TestDecodeEscapedSymbols(t *testing.T) {

for _, testInfo := range []struct {
Src string
Err string
}{
{
Src: "a\u0100\x80",
Err: `LDAP Result Code 201 "Filter Compile Error": ldap: error reading rune at position 3`,
},
{
Src: `start\d`,
Err: `LDAP Result Code 201 "Filter Compile Error": ldap: missing characters for escape in filter`,
},
{
Src: `\`,
Err: `LDAP Result Code 201 "Filter Compile Error": ldap: invalid characters for escape in filter: EOF`,
},
{
Src: `start\--end`,
Err: `LDAP Result Code 201 "Filter Compile Error": ldap: invalid characters for escape in filter: encoding/hex: invalid byte: U+002D '-'`,
},
{
Src: `start\d0\hh`,
Err: `LDAP Result Code 201 "Filter Compile Error": ldap: invalid characters for escape in filter: encoding/hex: invalid byte: U+0068 'h'`,
},
} {

res, err := decodeEscapedSymbols([]byte(testInfo.Src))
if err == nil || err.Error() != testInfo.Err {
t.Fatal(testInfo.Src, "=> ", err, "!=", testInfo.Err)
}
if res != "" {
t.Fatal(testInfo.Src, "=> ", "invalid result", res)
}
}
}

func TestInvalidFilter(t *testing.T) {
for _, filterStr := range testInvalidFilters {
if _, err := CompileFilter(filterStr); err == nil {
Expand Down
Loading

0 comments on commit a4f79d8

Please sign in to comment.