diff --git a/ast/check.go b/ast/check.go index 57652686b7..edba7b729d 100644 --- a/ast/check.go +++ b/ast/check.go @@ -1129,31 +1129,29 @@ func getObjectType(ref Ref, o types.Type, rule *Rule, d *types.DynamicProperty) return getObjectTypeRec(keys, o, d), nil } -func getRuleAnnotation(rule *Rule) (sannots []SchemaAnnotation) { - for _, annot := range rule.Module.Annotation { - schemaAnnots, ok := annot.(*SchemaAnnotations) - if ok && schemaAnnots.Scope == ruleScope && schemaAnnots.Rule == rule { - return schemaAnnots.SchemaAnnotation +func getRuleAnnotation(rule *Rule) (result []*SchemaAnnotation) { + for _, a := range rule.Module.Annotations { + other, ok := a.Node.(*Rule) + if !ok { + continue + } + if other == rule { + result = append(result, a.Schemas...) } } - return nil + return result } // NOTE: Currently, annotations must preceed the rule. In the future, this // restriction could be relaxed with other kinds of annotation scopes. -func processAnnotation(ss *SchemaSet, annot SchemaAnnotation, env *TypeEnv, rule *Rule) (Ref, types.Type, *Error) { +func processAnnotation(ss *SchemaSet, annot *SchemaAnnotation, env *TypeEnv, rule *Rule) (Ref, types.Type, *Error) { if ss == nil { return nil, nil, NewError(TypeErr, rule.Location, "schemas need to be supplied for the annotation: %s", annot.Schema) } - schemaRef, err := ParseRef(annot.Schema) - if err != nil { - return nil, nil, NewError(TypeErr, rule.Location, "schema is not well formed in annotation: %s", annot.Schema) - } - - schema := ss.Get(schemaRef) + schema := ss.Get(annot.Schema) if schema == nil { - return nil, nil, NewError(TypeErr, rule.Location, "schema does not exist for given path in annotation: %s", schemaRef.String()) + return nil, nil, NewError(TypeErr, rule.Location, "schema does not exist for given path in annotation: %s", annot.Schema) } tpe, err := loadSchema(schema) @@ -1161,10 +1159,5 @@ func processAnnotation(ss *SchemaSet, annot SchemaAnnotation, env *TypeEnv, rule return nil, nil, NewError(TypeErr, rule.Location, err.Error()) } - ref, err := ParseRef(annot.Path) - if err != nil { - return nil, nil, NewError(TypeErr, rule.Location, err.Error()) - } - - return ref, tpe, nil + return annot.Path, tpe, nil } diff --git a/ast/check_test.go b/ast/check_test.go index 4dbe7a3487..07279c32cd 100644 --- a/ast/check_test.go +++ b/ast/check_test.go @@ -1446,57 +1446,6 @@ default allow = false # scope: rule # schemas: # - input: schema["badpath"] -whocan[user] { - access = acl[user] - access[_] == input.operation -}` - - module5 := ` -package policy - -import data.acl -import input - -default allow = false - -# METADATA -# scope: rule -# schemas: -# - badref: schema["whocan-input-schema"] -whocan[user] { - access = acl[user] - access[_] == input.operation -}` - - module6 := ` -package policy - -import data.acl -import input - -default allow = false - -# METADATA -# scope: rule -# schemas: -# - data/acl: schema/acl-schema -whocan[user] { - access = acl[user] - access[_] == input.operation -}` - - module7 := ` -package policy - -import data.acl -import input - -default allow = false - -# METADATA -# scope: rule -# schemas: -# - input= schema["whocan-input-schema"] whocan[user] { access = acl[user] access[_] == input.operation @@ -1821,9 +1770,6 @@ whocan[user] { "correct data override": {module: module2, schemaSet: schemaSet}, "incorrect data override": {module: module3, schemaSet: schemaSet, err: "undefined ref: input.user"}, "schema not exist in annotation path": {module: module4, schemaSet: schemaSet, err: "schema does not exist for given path in annotation"}, - "non ref in annotation": {module: module5, schemaSet: schemaSet, err: "expected ref but got"}, - "Ill-structured annotation with bad path": {module: module6, schemaSet: schemaSet, err: "schema is not well formed in annotation"}, - "Ill-structured (invalid) annotation": {module: module7, schemaSet: schemaSet, err: "unable to unmarshall the metadata yaml in comment"}, "empty schema set": {module: module1, schemaSet: nil, err: "schemas need to be supplied for the annotation"}, "overriding ref with length greater than one and not existing": {module: module8, schemaSet: schemaSet, err: "undefined ref: input.apple.banana"}, "overriding ref with length greater than one and existing prefix": {module: module9, schemaSet: schemaSet}, diff --git a/ast/compare.go b/ast/compare.go index c2308bb92c..7f538ecbad 100644 --- a/ast/compare.go +++ b/ast/compare.go @@ -192,6 +192,9 @@ func Compare(a, b interface{}) int { case *Package: b := b.(*Package) return a.Compare(b) + case *Annotations: + b := b.(*Annotations) + return a.Compare(b) case *Module: b := b.(*Module) return a.Compare(b) @@ -251,6 +254,8 @@ func sortOrder(x interface{}) int { return 1001 case *Package: return 1002 + case *Annotations: + return 1003 case *Module: return 10000 } @@ -276,6 +281,25 @@ func importsCompare(a, b []*Import) int { return 0 } +func annotationsCompare(a, b []*Annotations) int { + minLen := len(a) + if len(b) < minLen { + minLen = len(b) + } + for i := 0; i < minLen; i++ { + if cmp := a[i].Compare(b[i]); cmp != 0 { + return cmp + } + } + if len(a) < len(b) { + return -1 + } + if len(b) < len(a) { + return 1 + } + return 0 +} + func rulesCompare(a, b []*Rule) int { minLen := len(a) if len(b) < minLen { diff --git a/ast/compare_test.go b/ast/compare_test.go index 72050406f8..b03c60f52b 100644 --- a/ast/compare_test.go +++ b/ast/compare_test.go @@ -4,7 +4,9 @@ package ast -import "testing" +import ( + "testing" +) func TestCompare(t *testing.T) { @@ -97,4 +99,214 @@ import input.x.z`) if result != -1 { t.Errorf("Expected %v to be less than %v but got: %v", a, b, result) } + + var err error + + a, err = ParseModuleWithOpts("test.rego", `package a + +# METADATA +# scope: rule +# schemas: +# - input: schema.a +p := 7`, ParserOptions{ProcessAnnotation: true}) + if err != nil { + t.Fatal(err) + } + + b, err = ParseModuleWithOpts("test.rego", `package a + +# METADATA +# scope: rule +# schemas: +# - input: schema.b +p := 7`, ParserOptions{ProcessAnnotation: true}) + if err != nil { + t.Fatal(err) + } + + result = Compare(a, b) + + if result != -1 { + t.Errorf("Expected %v to be less than %v but got: %v", a, b, result) + } +} + +func TestCompareAnnotations(t *testing.T) { + + tests := []struct { + note string + a string + b string + exp int + }{ + { + note: "same", + a: ` +# METADATA +# scope: a`, + b: ` +# METADATA +# scope: a`, + exp: 0, + }, + { + note: "unknown scope", + a: ` +# METADATA +# scope: rule`, + b: ` +# METADATA +# scope: a`, + exp: 1, + }, + { + note: "unknown scope - less than", + a: ` +# METADATA +# scope: a`, + b: ` +# METADATA +# scope: rule`, + exp: -1, + }, + { + note: "unknown scope - greater than - lexigraphical", + a: ` +# METADATA +# scope: b`, + b: ` +# METADATA +# scope: a`, + exp: 1, + }, + { + note: "unknown scope - less than - lexigraphical", + a: ` +# METADATA +# scope: b`, + b: ` +# METADATA +# scope: c`, + exp: -1, + }, + { + note: "schema", + a: ` +# METADATA +# scope: rule +# schemas: +# - input: schema`, + b: ` +# METADATA +# scope: rule +# schemas: +# - input: schema`, + exp: 0, + }, + { + note: "schema - less than", + a: ` +# METADATA +# scope: rule +# schemas: +# - input.a: schema`, + b: ` +# METADATA +# scope: rule +# schemas: +# - input.b: schema`, + exp: -1, + }, + { + note: "schema - greater than", + a: ` +# METADATA +# scope: rule +# schemas: +# - input.b: schema`, + b: ` +# METADATA +# scope: rule +# schemas: +# - input.a: schema`, + exp: 1, + }, + { + note: "schema - less than (fewer)", + a: ` +# METADATA +# scope: rule +# schemas: +# - input.a: schema`, + b: ` +# METADATA +# scope: rule +# schemas: +# - input.a: schema +# - input.b: schema`, + exp: -1, + }, + { + note: "schema - greater than (more)", + a: ` +# METADATA +# scope: rule +# schemas: +# - input.a: schema +# - input.b: schema`, + b: ` +# METADATA +# scope: rule +# schemas: +# - input.a: schema`, + exp: 1, + }, + { + note: "schema - less than - lexigraphical", + a: ` +# METADATA +# scope: rule +# schemas: +# - input: schema.a`, + b: ` +# METADATA +# scope: rule +# schemas: +# - input: schema.b`, + exp: -1, + }, + { + note: "schema - greater than - lexigraphical", + a: ` +# METADATA +# scope: rule +# schemas: +# - input: schema.c`, + b: ` +# METADATA +# scope: rule +# schemas: +# - input: schema.b`, + exp: 1, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + stmts, _, err := ParseStatementsWithOpts("test.rego", tc.a, ParserOptions{ProcessAnnotation: true}) + if err != nil { + t.Fatal(err) + } + a := stmts[0].(*Annotations) + stmts, _, err = ParseStatementsWithOpts("test.rego", tc.b, ParserOptions{ProcessAnnotation: true}) + if err != nil { + t.Fatal(err) + } + b := stmts[0].(*Annotations) + result := a.Compare(b) + if result != tc.exp { + t.Fatalf("Expected %d but got %v for %v and %v", tc.exp, result, a, b) + } + }) + } } diff --git a/ast/parser.go b/ast/parser.go index f0c5373e07..31f09b2616 100644 --- a/ast/parser.go +++ b/ast/parser.go @@ -5,11 +5,11 @@ package ast import ( + "bytes" "encoding/json" "fmt" "io" "math/big" - "strings" "gopkg.in/yaml.v2" @@ -93,75 +93,19 @@ func (p *Parser) WithProcessAnnotation(processAnnotation bool) *Parser { return p } -const metadata = "METADATA" -const ruleScope = "rule" - -// Metadata is used to unmarshal the policy metadata information -type Metadata struct { - Scope string `yaml:"scope"` - Schemas []map[string]string `yaml:"schemas"` -} - -// getAnnotation returns annotations in the comment if any -func (p *Parser) getAnnotation(rule *Rule, endYamlLine int) (Annotations, error) { - var metadataYaml []byte - var startYamlLine, currentYamlLine, prevYamlLine int - for i := 0; i < len(p.s.comments); i++ { - comment := p.s.comments[i] - currentYamlLine = comment.Location.Row - if currentYamlLine > endYamlLine { //comment comes after the rule - not relevant - break - } - - if currentYamlLine != (prevYamlLine + 1) { //comment not part of the same block - not relevant - startYamlLine = 0 - metadataYaml = nil - } - - if strings.HasPrefix((strings.TrimSpace(string(comment.Text))), metadata) && comment.Location.Col == 1 { // found METADATA signalling start in a block comment - startYamlLine = currentYamlLine + 1 - metadataYaml = make([]byte, 0) - } - - if startYamlLine != 0 && currentYamlLine >= startYamlLine && currentYamlLine <= endYamlLine && comment.Location.Col == 1 { //build yaml content from block comment only - metadataYaml = append(metadataYaml, comment.Text...) - metadataYaml = append(metadataYaml, []byte("\n")...) - } - prevYamlLine = currentYamlLine - } - - if prevYamlLine == endYamlLine && len(metadataYaml) > 0 { - metadata := &Metadata{} - err := yaml.Unmarshal(metadataYaml, metadata) - if err != nil { - return nil, fmt.Errorf("unable to unmarshall the metadata yaml in comment") - } - - if metadata.Scope == ruleScope && metadata.Schemas != nil { - var sannot []SchemaAnnotation - for _, schemas := range metadata.Schemas { - for path, schema := range schemas { - sannot = append(sannot, SchemaAnnotation{Path: path, Schema: schema}) - } - } - return &SchemaAnnotations{SchemaAnnotation: sannot, - Scope: ruleScope, - Rule: rule}, nil - } - } - return nil, nil - -} +const ( + annotationScopeRule = "rule" +) // Parse will read the Rego source and parse statements and // comments as they are found. Any errors encountered while // parsing will be accumulated and returned as a list of Errors. -func (p *Parser) Parse() ([]Statement, []*Comment, []Annotations, Errors) { +func (p *Parser) Parse() ([]Statement, []*Comment, Errors) { var err error p.s.s, err = scanner.New(p.r) if err != nil { - return nil, nil, nil, Errors{ + return nil, nil, Errors{ &Error{ Code: ParseErr, Message: err.Error(), @@ -174,7 +118,6 @@ func (p *Parser) Parse() ([]Statement, []*Comment, []Annotations, Errors) { p.scan() var stmts []Statement - var annotations []Annotations // Read from the scanner until the last token is reached or no statements // can be parsed. Attempt to parse package statements, import statements, @@ -209,17 +152,6 @@ func (p *Parser) Parse() ([]Statement, []*Comment, []Annotations, Errors) { if rules := p.parseRules(); rules != nil { for i := range rules { stmts = append(stmts, rules[i]) - // Append schema annotation to rule if there is one, and if processAnnotation option is on - if p.po.ProcessAnnotation { - ruleLoc := rules[i].Location.Row - annot, err := p.getAnnotation(rules[i], ruleLoc-1) - if err != nil { - p.error(rules[i].Location, err.Error()) - } - if annot != nil { - annotations = append(annotations, annot) - } - } } continue } else if len(p.s.errors) > 0 { @@ -237,7 +169,43 @@ func (p *Parser) Parse() ([]Statement, []*Comment, []Annotations, Errors) { break } - return stmts, p.s.comments, annotations, p.s.errors + if p.po.ProcessAnnotation { + stmts = p.parseAnnotations(stmts) + } + + return stmts, p.s.comments, p.s.errors +} + +func (p *Parser) parseAnnotations(stmts []Statement) []Statement { + + var hint = []byte("METADATA") + var curr *metadataParser + var blocks []*metadataParser + + for i := 0; i < len(p.s.comments); i++ { + if curr != nil { + if p.s.comments[i].Location.Row == p.s.comments[i-1].Location.Row+1 && p.s.comments[i].Location.Col == 1 { + curr.Append(p.s.comments[i]) + continue + } + curr = nil + } + if bytes.HasPrefix(bytes.TrimSpace(p.s.comments[i].Text), hint) { + curr = newMetadataParser(p.s.comments[i].Location) + blocks = append(blocks, curr) + } + } + + for _, b := range blocks { + a, err := b.Parse() + if err != nil { + p.error(b.loc, err.Error()) + } else { + stmts = append(stmts, a) + } + } + + return stmts } func (p *Parser) parsePackage() *Package { @@ -1585,3 +1553,84 @@ func (p *Parser) validateDefaultRuleValue(rule *Rule) bool { vis.Walk(rule.Head.Value.Value) return valid } + +type rawAnnotation struct { + Scope string `json:"scope"` + Schemas []rawSchemaAnnotation `json:"schemas"` +} + +type rawSchemaAnnotation map[string]string + +type metadataParser struct { + buf *bytes.Buffer + loc *location.Location +} + +func newMetadataParser(loc *Location) *metadataParser { + return &metadataParser{loc: loc, buf: bytes.NewBuffer(nil)} +} + +func (b *metadataParser) Append(c *Comment) { + b.buf.Write(bytes.TrimPrefix(c.Text, []byte(" "))) + b.buf.WriteByte('\n') +} + +func (b *metadataParser) Parse() (*Annotations, error) { + + var raw rawAnnotation + + if len(bytes.TrimSpace(b.buf.Bytes())) == 0 { + return nil, fmt.Errorf("expected METADATA block, found whitespace") + } + + // TODO(tsandall): how to improve locations of errors? The YAML parser + // doesn't include line numbers in the error API. + if err := yaml.Unmarshal(b.buf.Bytes(), &raw); err != nil { + return nil, err + } + + var result Annotations + result.Scope = raw.Scope + + for _, pair := range raw.Schemas { + var k, v string + for k, v = range pair { + } + kr, err := ParseRef(k) + if err != nil { + return nil, fmt.Errorf("invalid document reference") + } + vr, err := parseSchemaRef(v) + if err != nil { + return nil, err + } + result.Schemas = append(result.Schemas, &SchemaAnnotation{ + Path: kr, + Schema: vr, + }) + } + + result.Location = b.loc + return &result, nil +} + +var errInvalidSchemaRef = fmt.Errorf("invalid schema reference") + +func parseSchemaRef(s string) (Ref, error) { + + term, err := ParseTerm(s) + if err == nil { + switch v := term.Value.(type) { + case Var: + if term.Equal(SchemaRootDocument) { + return SchemaRootRef.Copy(), nil + } + case Ref: + if v.HasPrefix(SchemaRootRef) { + return v, nil + } + } + } + + return nil, errInvalidSchemaRef +} diff --git a/ast/parser_ext.go b/ast/parser_ext.go index a92f8ff19e..5e6aefcc55 100644 --- a/ast/parser_ext.go +++ b/ast/parser_ext.go @@ -425,11 +425,11 @@ func ParseModule(filename, input string) (*Module, error) { // For details on Module objects and their fields, see policy.go. // Empty input will return nil, nil. func ParseModuleWithOpts(filename, input string, popts ParserOptions) (*Module, error) { - stmts, comments, annotations, err := ParseStatementsWithOpts(filename, input, popts) + stmts, comments, err := ParseStatementsWithOpts(filename, input, popts) if err != nil { return nil, err } - return parseModule(filename, stmts, comments, annotations) + return parseModule(filename, stmts, comments) } // ParseBody returns exactly one body. @@ -570,38 +570,30 @@ func (a commentKey) Compare(other commentKey) int { return 0 } -// ParseStatements returns a slice of parsed statements. -// This is the default return value from the parser. +// ParseStatements is deprecated. Use ParseStatementWithOpts instead. func ParseStatements(filename, input string) ([]Statement, []*Comment, error) { - - stmts, comment, _, errs := NewParser().WithFilename(filename).WithReader(bytes.NewBufferString(input)).Parse() - - if len(errs) > 0 { - return nil, nil, errs - } - - return stmts, comment, nil + return ParseStatementsWithOpts(filename, input, ParserOptions{}) } -// ParseStatementsWithOpts returns a slice of parsed statements, and has an additional input ParserOptions -// This is the default return value from the parser. -func ParseStatementsWithOpts(filename, input string, popts ParserOptions) ([]Statement, []*Comment, []Annotations, error) { +// ParseStatementsWithOpts returns a slice of parsed statements. This is the +// default return value from the parser. +func ParseStatementsWithOpts(filename, input string, popts ParserOptions) ([]Statement, []*Comment, error) { parser := NewParser().WithFilename(filename).WithReader(bytes.NewBufferString(input)) if popts.ProcessAnnotation { parser.WithProcessAnnotation(popts.ProcessAnnotation) } - stmts, comment, annotations, errs := parser.Parse() + stmts, comments, errs := parser.Parse() if len(errs) > 0 { - return nil, nil, nil, errs + return nil, nil, errs } - return stmts, comment, annotations, nil + return stmts, comments, nil } -func parseModule(filename string, stmts []Statement, comments []*Comment, annotation []Annotations) (*Module, error) { +func parseModule(filename string, stmts []Statement, comments []*Comment) (*Module, error) { if len(stmts) == 0 { return nil, NewError(ParseErr, &Location{File: filename}, "empty module") @@ -616,14 +608,13 @@ func parseModule(filename string, stmts []Statement, comments []*Comment, annota } mod := &Module{ - Package: _package, - Annotation: annotation, + Package: _package, } // The comments slice only holds comments that were not their own statements. mod.Comments = append(mod.Comments, comments...) - for _, stmt := range stmts[1:] { + for i, stmt := range stmts[1:] { switch stmt := stmt.(type) { case *Import: mod.Imports = append(mod.Imports, stmt) @@ -636,20 +627,42 @@ func parseModule(filename string, stmts []Statement, comments []*Comment, annota errs = append(errs, NewError(ParseErr, stmt[0].Location, err.Error())) } else { mod.Rules = append(mod.Rules, rule) + + // NOTE(tsandall): the statement should now be interpreted as a + // rule so update the statement list. This is important for the + // logic below that associates annotations with statements. + stmts[i+1] = rule } case *Package: errs = append(errs, NewError(ParseErr, stmt.Loc(), "unexpected package")) - case *Comment: // Ignore comments, they're handled above. + case *Annotations: + mod.Annotations = append(mod.Annotations, stmt) + case *Comment: + // Ignore comments, they're handled above. default: panic("illegal value") // Indicates grammar is out-of-sync with code. } } - if len(errs) == 0 { - return mod, nil + if len(errs) > 0 { + return nil, errs + } + + // Find first non-annotation statement following each annotation and attach + // the annotation to that statement. + for _, a := range mod.Annotations { + for _, stmt := range stmts { + _, ok := stmt.(*Annotations) + if !ok { + if stmt.Loc().Row > a.Location.Row { + a.Node = stmt + break + } + } + } } - return nil, errs + return mod, nil } func setRuleModule(rule *Rule, module *Module) { diff --git a/ast/parser_test.go b/ast/parser_test.go index 4304653585..f1c62834c5 100644 --- a/ast/parser_test.go +++ b/ast/parser_test.go @@ -2662,13 +2662,21 @@ else = { `), curElse.Head.Value.Location) } -func TestGetAnnotation(t *testing.T) { +func TestAnnotations(t *testing.T) { + + dataServers := MustParseRef("data.servers") + dataNetworks := MustParseRef("data.networks") + dataPorts := MustParseRef("data.ports") + + schemaServers := MustParseRef("schema.servers") + schemaNetworks := MustParseRef("schema.networks") + schemaPorts := MustParseRef("schema.ports") tests := []struct { note string module string expNumComments int - expAnnotations []Annotations + expAnnotations []*Annotations expError string }{ { @@ -2680,7 +2688,6 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing a single schema # METADATA # scope: rule # schemas: @@ -2690,12 +2697,15 @@ public_servers[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`, - expNumComments: 5, - expAnnotations: append(make([]Annotations, 0), - &SchemaAnnotations{ - SchemaAnnotation: append(make([]SchemaAnnotation, 0), SchemaAnnotation{Path: "data.servers", Schema: "schemas.servers"}), - Scope: ruleScope, - }), + expNumComments: 4, + expAnnotations: []*Annotations{ + &Annotations{ + Schemas: []*SchemaAnnotation{ + {Path: dataServers, Schema: schemaServers}, + }, + Scope: annotationScopeRule, + }, + }, }, { note: "Multiple annotations on multiple lines", @@ -2706,7 +2716,6 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: @@ -2718,12 +2727,17 @@ public_servers[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`, - expNumComments: 7, - expAnnotations: append(make([]Annotations, 0), - &SchemaAnnotations{ - SchemaAnnotation: append(make([]SchemaAnnotation, 0), SchemaAnnotation{Path: "data.servers", Schema: "schemas.servers"}, SchemaAnnotation{Path: "data.networks", Schema: "schemas.networks"}, SchemaAnnotation{Path: "data.ports", Schema: "schemas.ports"}), - Scope: ruleScope, - }), + expNumComments: 6, + expAnnotations: []*Annotations{ + &Annotations{ + Schemas: []*SchemaAnnotation{ + {Path: dataServers, Schema: schemaServers}, + {Path: dataNetworks, Schema: schemaNetworks}, + {Path: dataPorts, Schema: schemaPorts}, + }, + Scope: annotationScopeRule, + }, + }, }, { note: "Comment in between metadata and rule (valid)", @@ -2734,25 +2748,30 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: # - data.servers: schema.servers # - data.networks: schema.networks # - data.ports: schema.ports -#This is a comment after the metadata yaml + +# This is a comment after the metadata YAML public_servers[server] { server = servers[i]; server.ports[j] = ports[k].id ports[k].networks[l] = networks[m].id; networks[m].public = true }`, - expNumComments: 8, - expAnnotations: append(make([]Annotations, 0), - &SchemaAnnotations{ - SchemaAnnotation: append(make([]SchemaAnnotation, 0), SchemaAnnotation{Path: "data.servers", Schema: "schemas.servers"}, SchemaAnnotation{Path: "data.networks", Schema: "schemas.networks"}, SchemaAnnotation{Path: "data.ports", Schema: "schemas.ports"}), - Scope: ruleScope, - }), + expNumComments: 7, + expAnnotations: []*Annotations{ + &Annotations{ + Schemas: []*SchemaAnnotation{ + {Path: dataServers, Schema: schemaServers}, + {Path: dataNetworks, Schema: schemaNetworks}, + {Path: dataPorts, Schema: schemaPorts}, + }, + Scope: annotationScopeRule, + }, + }, }, { note: "Empty comment line in between metadata and rule (valid)", @@ -2763,7 +2782,6 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: @@ -2776,12 +2794,17 @@ public_servers[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`, - expNumComments: 8, - expAnnotations: append(make([]Annotations, 0), - &SchemaAnnotations{ - SchemaAnnotation: append(make([]SchemaAnnotation, 0), SchemaAnnotation{Path: "data.servers", Schema: "schemas.servers"}, SchemaAnnotation{Path: "data.networks", Schema: "schemas.networks"}, SchemaAnnotation{Path: "data.ports", Schema: "schemas.ports"}), - Scope: ruleScope, - }), + expNumComments: 7, + expAnnotations: []*Annotations{ + &Annotations{ + Schemas: []*SchemaAnnotation{ + {Path: dataServers, Schema: schemaServers}, + {Path: dataNetworks, Schema: schemaNetworks}, + {Path: dataPorts, Schema: schemaPorts}, + }, + Scope: annotationScopeRule, + }, + }, }, { note: "Ill-structured (invalid) metadata start", @@ -2792,7 +2815,6 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: @@ -2805,11 +2827,10 @@ public_servers[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`, - expNumComments: 8, - expAnnotations: nil, + expError: "rego_parse_error: yaml: line 7: could not find expected ':'", }, { - note: "Ill-structured (valid) annotation", + note: "Ill-structured (invalid) annotation document path", module: ` package opa.examples @@ -2817,22 +2838,38 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: -# - data/servers: schemas/servers +# - data/servers: schema.servers public_servers[server] { server = servers[i]; server.ports[j] = ports[k].id ports[k].networks[l] = networks[m].id; networks[m].public = true }`, - expNumComments: 5, - expAnnotations: append(make([]Annotations, 0), - &SchemaAnnotations{ - SchemaAnnotation: append(make([]SchemaAnnotation, 0), SchemaAnnotation{Path: "data/servers", Schema: "schemas/servers"}), - Scope: ruleScope, - }), + expNumComments: 4, + expError: "rego_parse_error: invalid document reference", + }, + { + note: "Ill-structured (invalid) annotation schema path", + module: ` +package opa.examples + +import data.servers +import data.networks +import data.ports + +# METADATA +# scope: rule +# schemas: +# - data.servers: schema/servers +public_servers[server] { + server = servers[i]; server.ports[j] = ports[k].id + ports[k].networks[l] = networks[m].id; + networks[m].public = true +}`, + expNumComments: 4, + expError: "rego_parse_error: invalid schema reference", }, { note: "Ill-structured (invalid) annotation", @@ -2843,7 +2880,6 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: @@ -2854,8 +2890,7 @@ public_servers[server] { networks[m].public = true }`, expNumComments: 5, - expAnnotations: nil, - expError: "unable to unmarshall the metadata yaml in comment", + expError: "rego_parse_error: yaml: unmarshal errors:\n line 3: cannot unmarshal !!str", }, { note: "Indentation error in yaml", @@ -2866,7 +2901,6 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: @@ -2878,9 +2912,8 @@ public_servers[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`, - expNumComments: 7, - expAnnotations: nil, - expError: "unable to unmarshall the metadata yaml in comment", + expNumComments: 6, + expError: "rego_parse_error: yaml: line 3: did not find expected key", }, { note: "Multiple rules with and without metadata", @@ -2891,7 +2924,6 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: @@ -2910,12 +2942,17 @@ public_servers_1[server] { networks[m].public = true server.typo # won't catch this type error since rule has no schema metadata }`, - expNumComments: 8, - expAnnotations: append(make([]Annotations, 0), - &SchemaAnnotations{ - SchemaAnnotation: append(make([]SchemaAnnotation, 0), SchemaAnnotation{Path: "data.servers", Schema: "schemas.servers"}, SchemaAnnotation{Path: "data.networks", Schema: "schemas.networks"}, SchemaAnnotation{Path: "data.ports", Schema: "schemas.ports"}), - Scope: ruleScope, - }), + expNumComments: 7, + expAnnotations: []*Annotations{ + &Annotations{ + Schemas: []*SchemaAnnotation{ + {Path: dataServers, Schema: schemaServers}, + {Path: dataNetworks, Schema: schemaNetworks}, + {Path: dataPorts, Schema: schemaPorts}, + }, + Scope: annotationScopeRule, + }, + }, }, { note: "Multiple rules with metadata", @@ -2926,7 +2963,6 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: @@ -2935,7 +2971,6 @@ public_servers[server] { server = servers[i] } -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: @@ -2945,17 +2980,48 @@ public_servers_1[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`, - expNumComments: 11, - expAnnotations: append(make([]Annotations, 0), - &SchemaAnnotations{ - SchemaAnnotation: append(make([]SchemaAnnotation, 0), SchemaAnnotation{Path: "data.servers", Schema: "schemas.servers"}), - Scope: ruleScope, - Rule: MustParseRule(`public_servers[server] { server = servers[i] }`), - }, &SchemaAnnotations{ - SchemaAnnotation: append(make([]SchemaAnnotation, 0), SchemaAnnotation{Path: "data.networks", Schema: "schemas.networks"}, SchemaAnnotation{Path: "data.ports", Schema: "schemas.ports"}), - Scope: ruleScope, - Rule: MustParseRule(`public_servers_1[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`), - }), + expNumComments: 9, + expAnnotations: []*Annotations{ + &Annotations{ + Schemas: []*SchemaAnnotation{ + {Path: dataServers, Schema: schemaServers}, + }, + Scope: annotationScopeRule, + Node: MustParseRule(`public_servers[server] { server = servers[i] }`), + }, + &Annotations{ + Schemas: []*SchemaAnnotation{ + + {Path: dataNetworks, Schema: schemaNetworks}, + {Path: dataPorts, Schema: schemaPorts}, + }, + Scope: annotationScopeRule, + Node: MustParseRule(`public_servers_1[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`), + }, + }, + }, + { + note: "Empty annotation error due to whitespace following METADATA hint", + module: `package test + +# METADATA + +# scope: rule +p { input.x > 7 }`, + expError: "test.rego:3: rego_parse_error: expected METADATA block, found whitespace", + }, + { + note: "Annotation on constant", + module: ` +package test + +# METADATA +# scope: rule +p := 7`, + expNumComments: 2, + expAnnotations: []*Annotations{ + {Scope: annotationScopeRule}, + }, }, } @@ -2974,31 +3040,11 @@ public_servers_1[server] { } if len(mod.Comments) != tc.expNumComments { - t.Errorf("Expected %v comments but got %v", tc.expNumComments, len(mod.Comments)) - } - - annotations := mod.Annotation - if len(annotations) != len(tc.expAnnotations) { - t.Errorf("Expected %v annotations but got %v", len(tc.expAnnotations), len(annotations)) + t.Fatalf("Expected %v comments but got %v", tc.expNumComments, len(mod.Comments)) } - for _, annot := range annotations { - schemaAnnots, ok := annot.(*SchemaAnnotations) - if !ok { - t.Fatalf("Expected err: %v but no error from parse module", tc.expError) - } - for _, tcannot := range tc.expAnnotations { - tcschemaAnnots, ok := tcannot.(*SchemaAnnotations) - if !ok { - t.Fatalf("Expected err: %v but no error from parse module", tc.expError) - } - if schemaAnnots.Scope == ruleScope && tcschemaAnnots.Scope == ruleScope && tcschemaAnnots.Rule != nil && schemaAnnots.Rule.Head.Name == tcschemaAnnots.Rule.Head.Name { - if len(schemaAnnots.SchemaAnnotation) != len(tcschemaAnnots.SchemaAnnotation) { - t.Errorf("Expected %v annotations but got %v", len(schemaAnnots.SchemaAnnotation), len(tcschemaAnnots.SchemaAnnotation)) - } - } - } - + if annotationsCompare(tc.expAnnotations, mod.Annotations) != 0 { + t.Fatalf("expected %v but got %v", tc.expAnnotations, mod.Annotations) } }) } diff --git a/ast/policy.go b/ast/policy.go index 2b27ac499d..c640744245 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -131,11 +131,11 @@ type ( // within a namespace (defined by the package) and optional // dependencies on external documents (defined by imports). Module struct { - Package *Package `json:"package"` - Imports []*Import `json:"imports,omitempty"` - Rules []*Rule `json:"rules,omitempty"` - Comments []*Comment `json:"comments,omitempty"` - Annotation []Annotations `json:"annotation,omitempty"` + Package *Package `json:"package"` + Imports []*Import `json:"imports,omitempty"` + Annotations []*Annotations `json:"annotations,omitempty"` + Rules []*Rule `json:"rules,omitempty"` + Comments []*Comment `json:"comments,omitempty"` } // Comment contains the raw text from the comment in the definition. @@ -144,27 +144,18 @@ type ( Location *Location } - // Annotations contains information extracted from metadata in comments - Annotations interface { - annotationMaker() - - // NOTE(tsandall): these are temporary interfaces that are required to support copy operations. - // When we get rid of the rule pointers, these may not be needed. - copy(Node) Annotations - node() Node - } - - // SchemaAnnotations contains information about schemas - SchemaAnnotations struct { - SchemaAnnotation []SchemaAnnotation `json:"schemaannotation"` - Scope string `json:"scope"` - Rule *Rule `json:"-"` + // Annotations represents metadata attached to other AST nodes such as rules. + Annotations struct { + Node Node `json:"-"` + Location *Location `json:"-"` + Scope string `json:"scope"` + Schemas []*SchemaAnnotation `json:"schemas,omitempty"` } - // SchemaAnnotation contains information about a schema + // SchemaAnnotation contains a schema declaration for the document identified by the path. SchemaAnnotation struct { - Path string `json:"path"` - Schema string `json:"schema"` + Path Ref `json:"path"` + Schema Ref `json:"schema"` } // Package represents the namespace of the documents produced @@ -239,17 +230,108 @@ type ( } ) -func (s *SchemaAnnotations) copy(node Node) Annotations { +func (s *Annotations) String() string { + bs, _ := json.Marshal(s) + return string(bs) +} + +// Loc returns the location of this annotation. +func (s *Annotations) Loc() *Location { + return s.Location +} + +// SetLoc updates the location of this annotation. +func (s *Annotations) SetLoc(l *Location) { + s.Location = l +} + +// Compare returns an integer indicating if s is less than, equal to, or greater +// than other. +func (s *Annotations) Compare(other *Annotations) int { + + if cmp := scopeCompare(s.Scope, other.Scope); cmp != 0 { + return cmp + } + + max := len(s.Schemas) + if len(other.Schemas) < max { + max = len(other.Schemas) + } + + for i := 0; i < max; i++ { + if cmp := s.Schemas[i].Compare(other.Schemas[i]); cmp != 0 { + return cmp + } + } + + if len(s.Schemas) > len(other.Schemas) { + return 1 + } else if len(s.Schemas) < len(other.Schemas) { + return -1 + } + + return 0 +} + +// Copy returns a deep copy of s. +func (s *Annotations) Copy(node Node) *Annotations { cpy := *s - cpy.Rule = node.(*Rule) + cpy.Schemas = make([]*SchemaAnnotation, len(s.Schemas)) + for i := range cpy.Schemas { + cpy.Schemas[i] = s.Schemas[i].Copy() + } + cpy.Node = node return &cpy } -func (s *SchemaAnnotations) node() Node { - return s.Rule +// Copy returns a deep copy of s. +func (s *SchemaAnnotation) Copy() *SchemaAnnotation { + cpy := *s + return &cpy +} + +// Compare returns an integer indicating if s is less than, equal to, or greater +// than other. +func (s *SchemaAnnotation) Compare(other *SchemaAnnotation) int { + + if cmp := s.Path.Compare(other.Path); cmp != 0 { + return cmp + } + + if cmp := s.Schema.Compare(other.Schema); cmp != 0 { + return cmp + } + + return 0 +} + +func scopeCompare(s1, s2 string) int { + + o1 := scopeOrder(s1) + o2 := scopeOrder(s2) + + if o2 < o1 { + return 1 + } else if o2 > o1 { + return -1 + } + + if s1 < s2 { + return -1 + } else if s2 < s1 { + return 1 + } + + return 0 } -func (*SchemaAnnotations) annotationMaker() {} +func scopeOrder(s string) int { + switch s { + case annotationScopeRule: + return 1 + } + return 0 +} // Compare returns an integer indicating whether mod is less than, equal to, // or greater than other. @@ -268,6 +350,9 @@ func (mod *Module) Compare(other *Module) int { if cmp := importsCompare(mod.Imports, other.Imports); cmp != 0 { return cmp } + if cmp := annotationsCompare(mod.Annotations, other.Annotations); cmp != 0 { + return cmp + } return rulesCompare(mod.Rules, other.Rules) } @@ -276,32 +361,38 @@ func (mod *Module) Copy() *Module { cpy := *mod cpy.Rules = make([]*Rule, len(mod.Rules)) - // NOTE(tsandall): only construct the map if annotations are present. This is a temporary - // workaround to deal with the lack of a stable index mapping annotations to rules. - var rules map[Node]Node - if len(mod.Annotation) > 0 { - rules = make(map[Node]Node, len(mod.Rules)) + var nodes map[Node]Node + + if len(mod.Annotations) > 0 { + nodes = make(map[Node]Node) } for i := range mod.Rules { cpy.Rules[i] = mod.Rules[i].Copy() cpy.Rules[i].Module = &cpy - - if rules != nil { - rules[mod.Rules[i]] = cpy.Rules[i] + if nodes != nil { + nodes[mod.Rules[i]] = cpy.Rules[i] } } - cpy.Annotation = make([]Annotations, len(mod.Annotation)) - for i := range mod.Annotation { - cpy.Annotation[i] = mod.Annotation[i].copy(rules[mod.Annotation[i].node()]) - } - cpy.Imports = make([]*Import, len(mod.Imports)) for i := range mod.Imports { cpy.Imports[i] = mod.Imports[i].Copy() + if nodes != nil { + nodes[mod.Imports[i]] = cpy.Imports[i] + } } + cpy.Package = mod.Package.Copy() + if nodes != nil { + nodes[mod.Package] = cpy.Package + } + + cpy.Annotations = make([]*Annotations, len(mod.Annotations)) + for i := range mod.Annotations { + cpy.Annotations[i] = mod.Annotations[i].Copy(nodes[mod.Annotations[i].Node]) + } + return &cpy } @@ -312,16 +403,36 @@ func (mod *Module) Equal(other *Module) bool { func (mod *Module) String() string { buf := []string{} + + byNode := map[Node][]*Annotations{} + for _, a := range mod.Annotations { + byNode[a.Node] = append(byNode[a.Node], a) + } + + appendAnnotationStrings := func(buf []string, node Node) []string { + if as, ok := byNode[node]; ok { + for i := range as { + buf = append(buf, "# METADATA") + buf = append(buf, "# "+as[i].String()) + } + } + return buf + } + + buf = appendAnnotationStrings(buf, mod.Package) buf = append(buf, mod.Package.String()) + if len(mod.Imports) > 0 { buf = append(buf, "") for _, imp := range mod.Imports { + buf = appendAnnotationStrings(buf, imp) buf = append(buf, imp.String()) } } if len(mod.Rules) > 0 { buf = append(buf, "") for _, rule := range mod.Rules { + buf = appendAnnotationStrings(buf, rule) buf = append(buf, rule.String()) } } diff --git a/ast/policy_test.go b/ast/policy_test.go index 8ca9e377c9..bc273a09b5 100644 --- a/ast/policy_test.go +++ b/ast/policy_test.go @@ -16,7 +16,7 @@ import ( func TestModuleJSONRoundTrip(t *testing.T) { - mod := MustParseModule(`package a.b.c + mod, err := ParseModuleWithOpts("test.rego", `package a.b.c import data.x.y as z import data.u.i @@ -40,7 +40,15 @@ a = true { xs = {a: b | input.y[a] = "foo"; b = input.z["bar"]} } b = true { xs = {{"x": a[i].a} | a[i].n = "bob"; b[x]} } call_values { f(x) != g(x) } assigned := 1 -`) + +# METADATA +# scope: rule +metadata := 7 +`, ParserOptions{ProcessAnnotation: true}) + + if err != nil { + t.Fatal(err) + } bs, err := json.Marshal(mod) if err != nil { @@ -61,6 +69,10 @@ assigned := 1 if mod.Rules[3].Path().String() != "data.a.b.c.t" { t.Fatal("expected path data.a.b.c.t for 4th rule in module but got:", mod.Rules[3].Path()) } + + if len(roundtrip.Annotations) != 1 { + t.Fatal("expected exactly one annotation") + } } func TestBodyEmptyJSON(t *testing.T) { @@ -515,6 +527,48 @@ func TestSomeDeclString(t *testing.T) { } } +func TestAnnotationsString(t *testing.T) { + a := &Annotations{ + Scope: "foo", + Schemas: []*SchemaAnnotation{ + { + Path: MustParseRef("data.bar"), + Schema: MustParseRef("schema.baz"), + }, + }, + } + + // NOTE(tsandall): for now, annotations are represented as JSON objects + // which are a subset of YAML. We could improve this in the future. + exp := `{"scope":"foo","schemas":[{"path":[{"type":"var","value":"data"},{"type":"string","value":"bar"}],"schema":[{"type":"var","value":"schema"},{"type":"string","value":"baz"}]}]}` + + if exp != a.String() { + t.Fatalf("expected %q but got %q", exp, a.String()) + } +} + +func TestModuleStringAnnotations(t *testing.T) { + module, err := ParseModuleWithOpts("test.rego", `package test + +# METADATA +# scope: rule +p := 7`, ParserOptions{ProcessAnnotation: true}) + + if err != nil { + t.Fatal(err) + } + + exp := `package test + +# METADATA +# {"scope":"rule"} +p := 7 { true }` + + if module.String() != exp { + t.Fatalf("expected %q but got %q", exp, module.String()) + } +} + func TestCommentCopy(t *testing.T) { comment := &Comment{ Text: []byte("foo bar baz"), diff --git a/ast/transform.go b/ast/transform.go index 7e9af9cc26..c3509edc33 100644 --- a/ast/transform.go +++ b/ast/transform.go @@ -59,6 +59,15 @@ func Transform(t Transformer, x interface{}) (interface{}, error) { return nil, fmt.Errorf("illegal transform: %T != %T", y.Rules[i], rule) } } + for i := range y.Annotations { + a, err := Transform(t, y.Annotations[i]) + if err != nil { + return nil, err + } + if y.Annotations[i], ok = a.(*Annotations); !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", y.Annotations[i], a) + } + } for i := range y.Comments { comment, err := Transform(t, y.Comments[i]) if err != nil { diff --git a/ast/transform_test.go b/ast/transform_test.go index f9c71d34c2..72a38a9c5b 100644 --- a/ast/transform_test.go +++ b/ast/transform_test.go @@ -4,7 +4,9 @@ package ast -import "testing" +import ( + "testing" +) func TestTransform(t *testing.T) { module := MustParseModule(`package ex.this @@ -60,3 +62,45 @@ foo(x) = y { split(x, "that", y) } } } + +func TestTransformAnnotations(t *testing.T) { + + module, err := ParseModuleWithOpts("test.rego", `package test + +# METADATA +# scope: rule +p := 7`, ParserOptions{ProcessAnnotation: true}) + if err != nil { + t.Fatal(err) + } + + result, err := Transform(&GenericTransformer{ + func(x interface{}) (interface{}, error) { + if s, ok := x.(*Annotations); ok { + cpy := *s + cpy.Scope = "deadbeef" + return &cpy, nil + } + return x, nil + }, + }, module) + + resultMod, ok := result.(*Module) + if !ok { + t.Fatalf("Expected module from transform but got: %v", result) + } + + exp, err := ParseModuleWithOpts("test.rego", `package test + +# METADATA +# scope: deadbeef +p := 7`, ParserOptions{ProcessAnnotation: true}) + if err != nil { + t.Fatal(err) + } + + if resultMod.Compare(exp) != 0 { + t.Fatalf("expected:\n\n%v\n\ngot:\n\n%v", exp, resultMod) + } + +} diff --git a/ast/visit.go b/ast/visit.go index 105fb58ab8..139c4de3f5 100644 --- a/ast/visit.go +++ b/ast/visit.go @@ -52,6 +52,9 @@ func walk(v Visitor, x interface{}) { for _, r := range x.Rules { Walk(w, r) } + for _, a := range x.Annotations { + Walk(w, a) + } for _, c := range x.Comments { Walk(w, c) } @@ -280,6 +283,9 @@ func (vis *GenericVisitor) Walk(x interface{}) { for _, r := range x.Rules { vis.Walk(r) } + for _, a := range x.Annotations { + vis.Walk(a) + } for _, c := range x.Comments { vis.Walk(c) } @@ -398,6 +404,9 @@ func (vis *BeforeAfterVisitor) Walk(x interface{}) { for _, r := range x.Rules { vis.Walk(r) } + for _, a := range x.Annotations { + vis.Walk(a) + } for _, c := range x.Comments { vis.Walk(c) } diff --git a/ast/visit_test.go b/ast/visit_test.go index 7c338e86fb..1b3694508b 100644 --- a/ast/visit_test.go +++ b/ast/visit_test.go @@ -298,6 +298,28 @@ fn([x, y]) = z { json.unmarshal(x, z); z > y } } } +func TestVisitorAnnotations(t *testing.T) { + + module, err := ParseModuleWithOpts("test.rego", `package test + +# METADATA +# scope: rule +p := 7`, ParserOptions{ProcessAnnotation: true}) + if err != nil { + t.Fatal(err) + } + + vis := &testVis{} + + NewGenericVisitor(vis.Visit).Walk(module) + + exp := 20 + + if len(vis.elems) != exp { + t.Fatalf("expected %d elements but got %v: %v", exp, len(vis.elems), vis.elems) + } +} + func TestWalkVars(t *testing.T) { x := MustParseBody(`x = 1; data.abc[2] = y; y[z] = [q | q = 1]`) found := NewVarSet() diff --git a/cmd/parse.go b/cmd/parse.go index 25e789feff..06b6e49681 100644 --- a/cmd/parse.go +++ b/cmd/parse.go @@ -49,7 +49,7 @@ func parse(args []string, stdout io.Writer, stderr io.Writer) int { return 0 } - result, err := loader.Rego(args[0]) + result, err := loader.RegoWithOpts(args[0], ast.ParserOptions{ProcessAnnotation: true}) switch parseParams.format.String() { case parseFormatJSON: diff --git a/loader/loader.go b/loader/loader.go index a4c07ced77..970a950edc 100644 --- a/loader/loader.go +++ b/loader/loader.go @@ -108,12 +108,12 @@ type descriptor struct { } type fileLoader struct { - metrics metrics.Metrics - bvc *bundle.VerificationConfig - skipVerify bool - descriptors []*descriptor - files map[string]bundle.FileInfo - processAnnotation bool + metrics metrics.Metrics + bvc *bundle.VerificationConfig + skipVerify bool + descriptors []*descriptor + files map[string]bundle.FileInfo + opts ast.ParserOptions } // WithMetrics provides the metrics instance to use while loading @@ -136,7 +136,7 @@ func (fl *fileLoader) WithSkipBundleVerification(skipVerify bool) FileLoader { // WithProcessAnnotation enables or disables processing of schema annotations on rules func (fl *fileLoader) WithProcessAnnotation(processAnnotation bool) FileLoader { - fl.processAnnotation = processAnnotation + fl.opts.ProcessAnnotation = processAnnotation return fl } @@ -156,7 +156,7 @@ func (fl fileLoader) Filtered(paths []string, filter Filter) (*Result, error) { return err } - result, err := loadKnownTypes(path, bs, fl.metrics, fl.processAnnotation) + result, err := loadKnownTypes(path, bs, fl.metrics, fl.opts) if err != nil { if !isUnrecognizedFile(err) { return err @@ -164,7 +164,7 @@ func (fl fileLoader) Filtered(paths []string, filter Filter) (*Result, error) { if depth > 0 { return nil } - result, err = loadFileForAnyType(path, bs, fl.metrics) + result, err = loadFileForAnyType(path, bs, fl.metrics, fl.opts) if err != nil { return err } @@ -187,7 +187,7 @@ func (fl fileLoader) AsBundle(path string) (*bundle.Bundle, error) { WithMetrics(fl.metrics). WithBundleVerificationConfig(fl.bvc). WithSkipBundleVerification(fl.skipVerify). - WithProcessAnnotations(fl.processAnnotation) + WithProcessAnnotations(fl.opts.ProcessAnnotation) // For bundle directories add the full path in front of module file names // to simplify debugging. @@ -387,8 +387,13 @@ func AllRegos(paths []string) (*Result, error) { }) } -// Rego returns a RegoFile object loaded from the given path. +// Rego is deprecated. Use RegoWithOpts instead. func Rego(path string) (*RegoFile, error) { + return RegoWithOpts(path, ast.ParserOptions{}) +} + +// RegoWithOpts returns a RegoFile object loaded from the given path. +func RegoWithOpts(path string, opts ast.ParserOptions) (*RegoFile, error) { path, err := fileurl.Clean(path) if err != nil { return nil, err @@ -397,7 +402,7 @@ func Rego(path string) (*RegoFile, error) { if err != nil { return nil, err } - return loadRego(path, bs, metrics.New()) + return loadRego(path, bs, metrics.New(), opts) } // CleanPath returns the normalized version of a path that can be used as an identifier. @@ -578,12 +583,12 @@ func allRec(path string, filter Filter, errors *Errors, loaded *Result, depth in } } -func loadKnownTypes(path string, bs []byte, m metrics.Metrics, processAnnotation bool) (interface{}, error) { +func loadKnownTypes(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (interface{}, error) { switch filepath.Ext(path) { case ".json": return loadJSON(path, bs, m) case ".rego": - return loadRego(path, bs, m, processAnnotation) + return loadRego(path, bs, m, opts) case ".yaml", ".yml": return loadYAML(path, bs, m) default: @@ -598,8 +603,8 @@ func loadKnownTypes(path string, bs []byte, m metrics.Metrics, processAnnotation return nil, unrecognizedFile(path) } -func loadFileForAnyType(path string, bs []byte, m metrics.Metrics) (interface{}, error) { - module, err := loadRego(path, bs, m) +func loadFileForAnyType(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (interface{}, error) { + module, err := loadRego(path, bs, m, opts) if err == nil { return module, nil } @@ -620,17 +625,11 @@ func loadBundleFile(path string, bs []byte, m metrics.Metrics) (bundle.Bundle, e return br.Read() } -func loadRego(path string, bs []byte, m metrics.Metrics, parserOptions ...bool) (*RegoFile, error) { +func loadRego(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (*RegoFile, error) { m.Timer(metrics.RegoModuleParse).Start() var module *ast.Module var err error - if len(parserOptions) == 1 { - module, err = ast.ParseModuleWithOpts(path, string(bs), ast.ParserOptions{ - ProcessAnnotation: parserOptions[0], - }) - } else { - module, err = ast.ParseModule(path, string(bs)) - } + module, err = ast.ParseModuleWithOpts(path, string(bs), opts) m.Timer(metrics.RegoModuleParse).Stop() if err != nil { return nil, err