Skip to content

Commit

Permalink
fix: support set.add on nil sets in traits expression parser (#49429)
Browse files Browse the repository at this point in the history
  • Loading branch information
nklaassen authored Nov 26, 2024
1 parent e95594d commit c1c234f
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 25 deletions.
6 changes: 4 additions & 2 deletions lib/expression/dict.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ func (d Dict) addValues(key string, values ...string) Dict {
out[key] = NewSet(values...)
return out
}
s.Add(values...)
// Calling s.add would do an unnecessary extra copy since we already
// cloned the whole Dict. s.s.Add adds to the existing cloned set.
s.s.Add(values...)
return out
}

Expand All @@ -71,7 +73,7 @@ func (d Dict) remove(keys ...string) any {
func (d Dict) clone() Dict {
out := make(Dict, len(d))
for key, set := range d {
out[key] = Set{set.Clone()}
out[key] = set.clone()
}
return out
}
Expand Down
66 changes: 60 additions & 6 deletions lib/expression/evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,13 @@ import (
"github.com/gravitational/teleport/lib/utils/typical"
)

func TestEvaluateTraitsMap(t *testing.T) {
t.Parallel()

baseInputTraits := map[string][]string{
var (
baseInputTraits = map[string][]string{
"groups": []string{"devs", "security"},
"username": []string{"alice"},
}

tests := []struct {
testCases = []struct {
desc string
expressions map[string][]string
inputTraits map[string][]string
Expand Down Expand Up @@ -253,7 +251,31 @@ func TestEvaluateTraitsMap(t *testing.T) {
"localEmails": {"alice", "bob", "charlie", "darrell", "esther", "frank"},
},
},
{
desc: "methods on nil set from nonexistent map key",
expressions: map[string][]string{
"a": {`user.spec.traits["a"].add("a")`},
"b": {`ifelse(user.spec.traits["b"].contains("b"), set("z"), set("b"))`},
"c": {`ifelse(user.spec.traits["c"].contains_any(set("c")), set("z"), set("c"))`},
"d": {`ifelse(user.spec.traits["d"].isempty(), set("d"), set("z"))`},
"e": {`user.spec.traits["e"].remove("e")`},
"f": {`user.spec.traits["f"].remove("f").add("f")`},
},
inputTraits: baseInputTraits,
expectedTraits: map[string][]string{
"a": {"a"},
"b": {"b"},
"c": {"c"},
"d": {"d"},
"e": {},
"f": {"f"},
},
},
}
)

func TestEvaluateTraitsMap(t *testing.T) {
t.Parallel()

type evaluationEnv struct {
Traits Dict
Expand All @@ -270,7 +292,7 @@ func TestEvaluateTraitsMap(t *testing.T) {
attributeParser, err := NewTraitsExpressionParser[evaluationEnv](typicalEnvVar)
require.NoError(t, err)

for _, tc := range tests {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
result, err := EvaluateTraitsMap[evaluationEnv](
evaluationEnv{
Expand All @@ -292,3 +314,35 @@ func TestEvaluateTraitsMap(t *testing.T) {
})
}
}

func FuzzTraitsExpressionParser(f *testing.F) {
type evaluationEnv struct {
Traits Dict
}
parser, err := NewTraitsExpressionParser[evaluationEnv](map[string]typical.Variable{
"true": true,
"false": false,
"user.spec.traits": typical.DynamicMap[evaluationEnv, Set](func(env evaluationEnv) (Dict, error) {
return env.Traits, nil
}),
})
require.NoError(f, err)
for _, tc := range testCases {
for _, expressions := range tc.expressions {
for _, expression := range expressions {
f.Add(expression)
}
}
}
f.Fuzz(func(t *testing.T, expression string) {
expr, err := parser.Parse(expression)
if err != nil {
// Many/most fuzzed expressions won't parse, as long as we didn't
// panic that's okay.
return
}
// If the expression parsed, try to evaluate it, errors are okay just
// make sure we don't panic.
_, _ = expr.Evaluate(evaluationEnv{DictFromStringSliceMap(baseInputTraits)})
})
}
26 changes: 13 additions & 13 deletions lib/expression/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ func DefaultParserSpec[evaluationEnv any]() typical.ParserSpec[evaluationEnv] {
}),
"email.local": typical.UnaryFunction[evaluationEnv](
func(emails Set) (Set, error) {
locals, err := parse.EmailLocal(emails.Elements())
locals, err := parse.EmailLocal(emails.items())
if err != nil {
return Set{}, trace.Wrap(err)
}
return NewSet(locals...), nil
}),
"regexp.replace": typical.TernaryFunction[evaluationEnv](
func(inputs Set, match string, replacement string) (Set, error) {
replaced, err := parse.RegexpReplace(inputs.Elements(), match, replacement)
replaced, err := parse.RegexpReplace(inputs.items(), match, replacement)
if err != nil {
return Set{}, trace.Wrap(err)
}
Expand All @@ -101,7 +101,7 @@ func DefaultParserSpec[evaluationEnv any]() typical.ParserSpec[evaluationEnv] {
"strings.split": typical.BinaryFunction[evaluationEnv](
func(inputs Set, sep string) (Set, error) {
var outputs []string
for input := range inputs.Set {
for input := range inputs.s {
outputs = append(outputs, strings.Split(input, sep)...)
}
return NewSet(outputs...), nil
Expand Down Expand Up @@ -131,26 +131,26 @@ func DefaultParserSpec[evaluationEnv any]() typical.ParserSpec[evaluationEnv] {
}),
"contains_any": typical.BinaryFunction[evaluationEnv](
func(s1, s2 Set) (bool, error) {
for v := range s2.Set {
if s1.Contains(v) {
for v := range s2.s {
if s1.contains(v) {
return true, nil
}
}
return false, nil
}),
"is_empty": typical.UnaryFunction[evaluationEnv](
func(s Set) (bool, error) {
return len(s.Set) == 0, nil
return len(s.s) == 0, nil
}),
},
Methods: map[string]typical.Function{
"add": typical.BinaryVariadicFunction[evaluationEnv](
func(s Set, values ...string) (Set, error) {
return Set{s.Clone().Add(values...)}, nil
return s.add(values...), nil
}),
"contains": typical.BinaryFunction[evaluationEnv](
func(s Set, str string) (bool, error) {
return s.Contains(str), nil
return s.contains(str), nil
}),
"put": typical.TernaryFunction[evaluationEnv](
func(d Dict, key string, value Set) (Dict, error) {
Expand Down Expand Up @@ -185,16 +185,16 @@ func DefaultParserSpec[evaluationEnv any]() typical.ParserSpec[evaluationEnv] {
}),
"contains_any": typical.BinaryFunction[evaluationEnv](
func(s1, s2 Set) (bool, error) {
for v := range s2.Set {
if s1.Contains(v) {
for v := range s2.s {
if s1.contains(v) {
return true, nil
}
}
return false, nil
}),
"isempty": typical.UnaryFunction[evaluationEnv](
func(s Set) (bool, error) {
return len(s.Set) == 0, nil
return len(s.s) == 0, nil
}),
},
}
Expand Down Expand Up @@ -229,7 +229,7 @@ func traitsMapResultToSet(result any, expr string) (Set, error) {
func StringSliceMapFromDict(d Dict) map[string][]string {
m := make(map[string][]string, len(d))
for key, s := range d {
m[key] = s.Elements()
m[key] = s.items()
}
return m
}
Expand All @@ -249,7 +249,7 @@ func StringTransform(name string, input any, f func(string) string) (any, error)
case string:
return f(typedInput), nil
case Set:
return Set{utils.SetTransform(typedInput.Set, f)}, nil
return Set{utils.SetTransform(typedInput.s, f)}, nil
default:
return nil, trace.BadParameter("failed to evaluate argument to %s: expected string or set, got value of type %T", name, input)
}
Expand Down
33 changes: 29 additions & 4 deletions lib/expression/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,56 @@

package expression

import "github.com/gravitational/teleport/lib/utils"
import (
"github.com/gravitational/teleport/lib/utils"
)

// Set is a map of type string key and struct values. Set is a thin wrapper over
// the utils.Set[T] generic set type, allowing the Set to implement the
// interface(s) required for use with the expression package.
// The default value is an empty set and all methods are safe to call (even if
// the underlying map is nil).
type Set struct {
utils.Set[string]
s utils.Set[string]
}

// NewSet constructs a new set from an arbitrary collection of elements
func NewSet(values ...string) Set {
return Set{utils.NewSet(values...)}
}

// add creates a new Set containing all values in the receiver Set and adds
// [elements].
func (s Set) add(elements ...string) Set {
if s.s == nil {
return NewSet(elements...)
}
return Set{s.s.Clone().Add(elements...)}
}

// remove creates a new Set containing all values in the receiver Set, minus
// all supplied elements. Implements expression.Remover for Set.
func (s Set) remove(elements ...string) any {
return Set{s.Set.Clone().Remove(elements...)}
return Set{s.s.Clone().Remove(elements...)}
}

func (s Set) contains(element string) bool {
return s.s.Contains(element)
}

func (s Set) clone() Set {
return Set{s.s.Clone()}
}

func (s Set) items() []string {
return s.s.Elements()
}

// union computes the union of multiple sets
func union(sets ...Set) Set {
result := utils.NewSet[string]()
for _, set := range sets {
result.Union(set.Set)
result.Union(set.s)
}
return Set{result}
}

0 comments on commit c1c234f

Please sign in to comment.