From d3adb906f0bec616fccf411bf724b147261eb913 Mon Sep 17 00:00:00 2001 From: Torin Sandall Date: Mon, 19 Apr 2021 11:34:57 -0400 Subject: [PATCH] ast: Misc. refactoring on annotations support This commit combines a bunch of refactoring on annotations to support future work. Specifically: * Annotations are now normal AST nodes/statements. This means that annotations store locations and also implement String() and Compare(). Annotations are now correctly compared during module comparison and annotations are included in the module string representation (before annotations would be dropped when the module String() function was called.) Also, the visitor and transformer functions support annotations now. * Annotations are no longer hidden behind an interface. Instead, there is a single annotation struct that we can evolve over time. It was unclear how the Annotations interface was going to work in the long-term (e.g., callers would not be able to define their own annotation types since the parser needs to be aware of them.) With this change, Annotations are just structs now. We can extend the struct as needed going forward. Custom data can be stored in a dedicated field. * Annotation parsing has been refactored. We now attach annotations to the statement following the annotation. The parser will reject METADATA blocks that contain whitespace between the METADATA hint and the YAML block. Similarly, we no longer support trailing unindented comments that follow the METADATA block. Users can inject whitespace after the YAML block if they want to include trailing comments. * The opa parse subcommand now enables annotation processing. Signed-off-by: Torin Sandall --- ast/check.go | 33 +++--- ast/check_test.go | 54 ---------- ast/compare.go | 24 +++++ ast/compare_test.go | 214 +++++++++++++++++++++++++++++++++++++- ast/parser.go | 199 ++++++++++++++++++++++-------------- ast/parser_ext.go | 65 +++++++----- ast/parser_test.go | 232 +++++++++++++++++++++++++----------------- ast/policy.go | 193 +++++++++++++++++++++++++++-------- ast/policy_test.go | 58 ++++++++++- ast/transform.go | 9 ++ ast/transform_test.go | 46 ++++++++- ast/visit.go | 9 ++ ast/visit_test.go | 22 ++++ cmd/parse.go | 2 +- loader/loader.go | 47 +++++---- 15 files changed, 869 insertions(+), 338 deletions(-) diff --git a/ast/check.go b/ast/check.go index 57652686b7..edba7b729d 100644 --- a/ast/check.go +++ b/ast/check.go @@ -1129,31 +1129,29 @@ func getObjectType(ref Ref, o types.Type, rule *Rule, d *types.DynamicProperty) return getObjectTypeRec(keys, o, d), nil } -func getRuleAnnotation(rule *Rule) (sannots []SchemaAnnotation) { - for _, annot := range rule.Module.Annotation { - schemaAnnots, ok := annot.(*SchemaAnnotations) - if ok && schemaAnnots.Scope == ruleScope && schemaAnnots.Rule == rule { - return schemaAnnots.SchemaAnnotation +func getRuleAnnotation(rule *Rule) (result []*SchemaAnnotation) { + for _, a := range rule.Module.Annotations { + other, ok := a.Node.(*Rule) + if !ok { + continue + } + if other == rule { + result = append(result, a.Schemas...) } } - return nil + return result } // NOTE: Currently, annotations must preceed the rule. In the future, this // restriction could be relaxed with other kinds of annotation scopes. -func processAnnotation(ss *SchemaSet, annot SchemaAnnotation, env *TypeEnv, rule *Rule) (Ref, types.Type, *Error) { +func processAnnotation(ss *SchemaSet, annot *SchemaAnnotation, env *TypeEnv, rule *Rule) (Ref, types.Type, *Error) { if ss == nil { return nil, nil, NewError(TypeErr, rule.Location, "schemas need to be supplied for the annotation: %s", annot.Schema) } - schemaRef, err := ParseRef(annot.Schema) - if err != nil { - return nil, nil, NewError(TypeErr, rule.Location, "schema is not well formed in annotation: %s", annot.Schema) - } - - schema := ss.Get(schemaRef) + schema := ss.Get(annot.Schema) if schema == nil { - return nil, nil, NewError(TypeErr, rule.Location, "schema does not exist for given path in annotation: %s", schemaRef.String()) + return nil, nil, NewError(TypeErr, rule.Location, "schema does not exist for given path in annotation: %s", annot.Schema) } tpe, err := loadSchema(schema) @@ -1161,10 +1159,5 @@ func processAnnotation(ss *SchemaSet, annot SchemaAnnotation, env *TypeEnv, rule return nil, nil, NewError(TypeErr, rule.Location, err.Error()) } - ref, err := ParseRef(annot.Path) - if err != nil { - return nil, nil, NewError(TypeErr, rule.Location, err.Error()) - } - - return ref, tpe, nil + return annot.Path, tpe, nil } diff --git a/ast/check_test.go b/ast/check_test.go index 4dbe7a3487..07279c32cd 100644 --- a/ast/check_test.go +++ b/ast/check_test.go @@ -1446,57 +1446,6 @@ default allow = false # scope: rule # schemas: # - input: schema["badpath"] -whocan[user] { - access = acl[user] - access[_] == input.operation -}` - - module5 := ` -package policy - -import data.acl -import input - -default allow = false - -# METADATA -# scope: rule -# schemas: -# - badref: schema["whocan-input-schema"] -whocan[user] { - access = acl[user] - access[_] == input.operation -}` - - module6 := ` -package policy - -import data.acl -import input - -default allow = false - -# METADATA -# scope: rule -# schemas: -# - data/acl: schema/acl-schema -whocan[user] { - access = acl[user] - access[_] == input.operation -}` - - module7 := ` -package policy - -import data.acl -import input - -default allow = false - -# METADATA -# scope: rule -# schemas: -# - input= schema["whocan-input-schema"] whocan[user] { access = acl[user] access[_] == input.operation @@ -1821,9 +1770,6 @@ whocan[user] { "correct data override": {module: module2, schemaSet: schemaSet}, "incorrect data override": {module: module3, schemaSet: schemaSet, err: "undefined ref: input.user"}, "schema not exist in annotation path": {module: module4, schemaSet: schemaSet, err: "schema does not exist for given path in annotation"}, - "non ref in annotation": {module: module5, schemaSet: schemaSet, err: "expected ref but got"}, - "Ill-structured annotation with bad path": {module: module6, schemaSet: schemaSet, err: "schema is not well formed in annotation"}, - "Ill-structured (invalid) annotation": {module: module7, schemaSet: schemaSet, err: "unable to unmarshall the metadata yaml in comment"}, "empty schema set": {module: module1, schemaSet: nil, err: "schemas need to be supplied for the annotation"}, "overriding ref with length greater than one and not existing": {module: module8, schemaSet: schemaSet, err: "undefined ref: input.apple.banana"}, "overriding ref with length greater than one and existing prefix": {module: module9, schemaSet: schemaSet}, diff --git a/ast/compare.go b/ast/compare.go index c2308bb92c..7f538ecbad 100644 --- a/ast/compare.go +++ b/ast/compare.go @@ -192,6 +192,9 @@ func Compare(a, b interface{}) int { case *Package: b := b.(*Package) return a.Compare(b) + case *Annotations: + b := b.(*Annotations) + return a.Compare(b) case *Module: b := b.(*Module) return a.Compare(b) @@ -251,6 +254,8 @@ func sortOrder(x interface{}) int { return 1001 case *Package: return 1002 + case *Annotations: + return 1003 case *Module: return 10000 } @@ -276,6 +281,25 @@ func importsCompare(a, b []*Import) int { return 0 } +func annotationsCompare(a, b []*Annotations) int { + minLen := len(a) + if len(b) < minLen { + minLen = len(b) + } + for i := 0; i < minLen; i++ { + if cmp := a[i].Compare(b[i]); cmp != 0 { + return cmp + } + } + if len(a) < len(b) { + return -1 + } + if len(b) < len(a) { + return 1 + } + return 0 +} + func rulesCompare(a, b []*Rule) int { minLen := len(a) if len(b) < minLen { diff --git a/ast/compare_test.go b/ast/compare_test.go index 72050406f8..b03c60f52b 100644 --- a/ast/compare_test.go +++ b/ast/compare_test.go @@ -4,7 +4,9 @@ package ast -import "testing" +import ( + "testing" +) func TestCompare(t *testing.T) { @@ -97,4 +99,214 @@ import input.x.z`) if result != -1 { t.Errorf("Expected %v to be less than %v but got: %v", a, b, result) } + + var err error + + a, err = ParseModuleWithOpts("test.rego", `package a + +# METADATA +# scope: rule +# schemas: +# - input: schema.a +p := 7`, ParserOptions{ProcessAnnotation: true}) + if err != nil { + t.Fatal(err) + } + + b, err = ParseModuleWithOpts("test.rego", `package a + +# METADATA +# scope: rule +# schemas: +# - input: schema.b +p := 7`, ParserOptions{ProcessAnnotation: true}) + if err != nil { + t.Fatal(err) + } + + result = Compare(a, b) + + if result != -1 { + t.Errorf("Expected %v to be less than %v but got: %v", a, b, result) + } +} + +func TestCompareAnnotations(t *testing.T) { + + tests := []struct { + note string + a string + b string + exp int + }{ + { + note: "same", + a: ` +# METADATA +# scope: a`, + b: ` +# METADATA +# scope: a`, + exp: 0, + }, + { + note: "unknown scope", + a: ` +# METADATA +# scope: rule`, + b: ` +# METADATA +# scope: a`, + exp: 1, + }, + { + note: "unknown scope - less than", + a: ` +# METADATA +# scope: a`, + b: ` +# METADATA +# scope: rule`, + exp: -1, + }, + { + note: "unknown scope - greater than - lexigraphical", + a: ` +# METADATA +# scope: b`, + b: ` +# METADATA +# scope: a`, + exp: 1, + }, + { + note: "unknown scope - less than - lexigraphical", + a: ` +# METADATA +# scope: b`, + b: ` +# METADATA +# scope: c`, + exp: -1, + }, + { + note: "schema", + a: ` +# METADATA +# scope: rule +# schemas: +# - input: schema`, + b: ` +# METADATA +# scope: rule +# schemas: +# - input: schema`, + exp: 0, + }, + { + note: "schema - less than", + a: ` +# METADATA +# scope: rule +# schemas: +# - input.a: schema`, + b: ` +# METADATA +# scope: rule +# schemas: +# - input.b: schema`, + exp: -1, + }, + { + note: "schema - greater than", + a: ` +# METADATA +# scope: rule +# schemas: +# - input.b: schema`, + b: ` +# METADATA +# scope: rule +# schemas: +# - input.a: schema`, + exp: 1, + }, + { + note: "schema - less than (fewer)", + a: ` +# METADATA +# scope: rule +# schemas: +# - input.a: schema`, + b: ` +# METADATA +# scope: rule +# schemas: +# - input.a: schema +# - input.b: schema`, + exp: -1, + }, + { + note: "schema - greater than (more)", + a: ` +# METADATA +# scope: rule +# schemas: +# - input.a: schema +# - input.b: schema`, + b: ` +# METADATA +# scope: rule +# schemas: +# - input.a: schema`, + exp: 1, + }, + { + note: "schema - less than - lexigraphical", + a: ` +# METADATA +# scope: rule +# schemas: +# - input: schema.a`, + b: ` +# METADATA +# scope: rule +# schemas: +# - input: schema.b`, + exp: -1, + }, + { + note: "schema - greater than - lexigraphical", + a: ` +# METADATA +# scope: rule +# schemas: +# - input: schema.c`, + b: ` +# METADATA +# scope: rule +# schemas: +# - input: schema.b`, + exp: 1, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + stmts, _, err := ParseStatementsWithOpts("test.rego", tc.a, ParserOptions{ProcessAnnotation: true}) + if err != nil { + t.Fatal(err) + } + a := stmts[0].(*Annotations) + stmts, _, err = ParseStatementsWithOpts("test.rego", tc.b, ParserOptions{ProcessAnnotation: true}) + if err != nil { + t.Fatal(err) + } + b := stmts[0].(*Annotations) + result := a.Compare(b) + if result != tc.exp { + t.Fatalf("Expected %d but got %v for %v and %v", tc.exp, result, a, b) + } + }) + } } diff --git a/ast/parser.go b/ast/parser.go index f0c5373e07..31f09b2616 100644 --- a/ast/parser.go +++ b/ast/parser.go @@ -5,11 +5,11 @@ package ast import ( + "bytes" "encoding/json" "fmt" "io" "math/big" - "strings" "gopkg.in/yaml.v2" @@ -93,75 +93,19 @@ func (p *Parser) WithProcessAnnotation(processAnnotation bool) *Parser { return p } -const metadata = "METADATA" -const ruleScope = "rule" - -// Metadata is used to unmarshal the policy metadata information -type Metadata struct { - Scope string `yaml:"scope"` - Schemas []map[string]string `yaml:"schemas"` -} - -// getAnnotation returns annotations in the comment if any -func (p *Parser) getAnnotation(rule *Rule, endYamlLine int) (Annotations, error) { - var metadataYaml []byte - var startYamlLine, currentYamlLine, prevYamlLine int - for i := 0; i < len(p.s.comments); i++ { - comment := p.s.comments[i] - currentYamlLine = comment.Location.Row - if currentYamlLine > endYamlLine { //comment comes after the rule - not relevant - break - } - - if currentYamlLine != (prevYamlLine + 1) { //comment not part of the same block - not relevant - startYamlLine = 0 - metadataYaml = nil - } - - if strings.HasPrefix((strings.TrimSpace(string(comment.Text))), metadata) && comment.Location.Col == 1 { // found METADATA signalling start in a block comment - startYamlLine = currentYamlLine + 1 - metadataYaml = make([]byte, 0) - } - - if startYamlLine != 0 && currentYamlLine >= startYamlLine && currentYamlLine <= endYamlLine && comment.Location.Col == 1 { //build yaml content from block comment only - metadataYaml = append(metadataYaml, comment.Text...) - metadataYaml = append(metadataYaml, []byte("\n")...) - } - prevYamlLine = currentYamlLine - } - - if prevYamlLine == endYamlLine && len(metadataYaml) > 0 { - metadata := &Metadata{} - err := yaml.Unmarshal(metadataYaml, metadata) - if err != nil { - return nil, fmt.Errorf("unable to unmarshall the metadata yaml in comment") - } - - if metadata.Scope == ruleScope && metadata.Schemas != nil { - var sannot []SchemaAnnotation - for _, schemas := range metadata.Schemas { - for path, schema := range schemas { - sannot = append(sannot, SchemaAnnotation{Path: path, Schema: schema}) - } - } - return &SchemaAnnotations{SchemaAnnotation: sannot, - Scope: ruleScope, - Rule: rule}, nil - } - } - return nil, nil - -} +const ( + annotationScopeRule = "rule" +) // Parse will read the Rego source and parse statements and // comments as they are found. Any errors encountered while // parsing will be accumulated and returned as a list of Errors. -func (p *Parser) Parse() ([]Statement, []*Comment, []Annotations, Errors) { +func (p *Parser) Parse() ([]Statement, []*Comment, Errors) { var err error p.s.s, err = scanner.New(p.r) if err != nil { - return nil, nil, nil, Errors{ + return nil, nil, Errors{ &Error{ Code: ParseErr, Message: err.Error(), @@ -174,7 +118,6 @@ func (p *Parser) Parse() ([]Statement, []*Comment, []Annotations, Errors) { p.scan() var stmts []Statement - var annotations []Annotations // Read from the scanner until the last token is reached or no statements // can be parsed. Attempt to parse package statements, import statements, @@ -209,17 +152,6 @@ func (p *Parser) Parse() ([]Statement, []*Comment, []Annotations, Errors) { if rules := p.parseRules(); rules != nil { for i := range rules { stmts = append(stmts, rules[i]) - // Append schema annotation to rule if there is one, and if processAnnotation option is on - if p.po.ProcessAnnotation { - ruleLoc := rules[i].Location.Row - annot, err := p.getAnnotation(rules[i], ruleLoc-1) - if err != nil { - p.error(rules[i].Location, err.Error()) - } - if annot != nil { - annotations = append(annotations, annot) - } - } } continue } else if len(p.s.errors) > 0 { @@ -237,7 +169,43 @@ func (p *Parser) Parse() ([]Statement, []*Comment, []Annotations, Errors) { break } - return stmts, p.s.comments, annotations, p.s.errors + if p.po.ProcessAnnotation { + stmts = p.parseAnnotations(stmts) + } + + return stmts, p.s.comments, p.s.errors +} + +func (p *Parser) parseAnnotations(stmts []Statement) []Statement { + + var hint = []byte("METADATA") + var curr *metadataParser + var blocks []*metadataParser + + for i := 0; i < len(p.s.comments); i++ { + if curr != nil { + if p.s.comments[i].Location.Row == p.s.comments[i-1].Location.Row+1 && p.s.comments[i].Location.Col == 1 { + curr.Append(p.s.comments[i]) + continue + } + curr = nil + } + if bytes.HasPrefix(bytes.TrimSpace(p.s.comments[i].Text), hint) { + curr = newMetadataParser(p.s.comments[i].Location) + blocks = append(blocks, curr) + } + } + + for _, b := range blocks { + a, err := b.Parse() + if err != nil { + p.error(b.loc, err.Error()) + } else { + stmts = append(stmts, a) + } + } + + return stmts } func (p *Parser) parsePackage() *Package { @@ -1585,3 +1553,84 @@ func (p *Parser) validateDefaultRuleValue(rule *Rule) bool { vis.Walk(rule.Head.Value.Value) return valid } + +type rawAnnotation struct { + Scope string `json:"scope"` + Schemas []rawSchemaAnnotation `json:"schemas"` +} + +type rawSchemaAnnotation map[string]string + +type metadataParser struct { + buf *bytes.Buffer + loc *location.Location +} + +func newMetadataParser(loc *Location) *metadataParser { + return &metadataParser{loc: loc, buf: bytes.NewBuffer(nil)} +} + +func (b *metadataParser) Append(c *Comment) { + b.buf.Write(bytes.TrimPrefix(c.Text, []byte(" "))) + b.buf.WriteByte('\n') +} + +func (b *metadataParser) Parse() (*Annotations, error) { + + var raw rawAnnotation + + if len(bytes.TrimSpace(b.buf.Bytes())) == 0 { + return nil, fmt.Errorf("expected METADATA block, found whitespace") + } + + // TODO(tsandall): how to improve locations of errors? The YAML parser + // doesn't include line numbers in the error API. + if err := yaml.Unmarshal(b.buf.Bytes(), &raw); err != nil { + return nil, err + } + + var result Annotations + result.Scope = raw.Scope + + for _, pair := range raw.Schemas { + var k, v string + for k, v = range pair { + } + kr, err := ParseRef(k) + if err != nil { + return nil, fmt.Errorf("invalid document reference") + } + vr, err := parseSchemaRef(v) + if err != nil { + return nil, err + } + result.Schemas = append(result.Schemas, &SchemaAnnotation{ + Path: kr, + Schema: vr, + }) + } + + result.Location = b.loc + return &result, nil +} + +var errInvalidSchemaRef = fmt.Errorf("invalid schema reference") + +func parseSchemaRef(s string) (Ref, error) { + + term, err := ParseTerm(s) + if err == nil { + switch v := term.Value.(type) { + case Var: + if term.Equal(SchemaRootDocument) { + return SchemaRootRef.Copy(), nil + } + case Ref: + if v.HasPrefix(SchemaRootRef) { + return v, nil + } + } + } + + return nil, errInvalidSchemaRef +} diff --git a/ast/parser_ext.go b/ast/parser_ext.go index a92f8ff19e..5e6aefcc55 100644 --- a/ast/parser_ext.go +++ b/ast/parser_ext.go @@ -425,11 +425,11 @@ func ParseModule(filename, input string) (*Module, error) { // For details on Module objects and their fields, see policy.go. // Empty input will return nil, nil. func ParseModuleWithOpts(filename, input string, popts ParserOptions) (*Module, error) { - stmts, comments, annotations, err := ParseStatementsWithOpts(filename, input, popts) + stmts, comments, err := ParseStatementsWithOpts(filename, input, popts) if err != nil { return nil, err } - return parseModule(filename, stmts, comments, annotations) + return parseModule(filename, stmts, comments) } // ParseBody returns exactly one body. @@ -570,38 +570,30 @@ func (a commentKey) Compare(other commentKey) int { return 0 } -// ParseStatements returns a slice of parsed statements. -// This is the default return value from the parser. +// ParseStatements is deprecated. Use ParseStatementWithOpts instead. func ParseStatements(filename, input string) ([]Statement, []*Comment, error) { - - stmts, comment, _, errs := NewParser().WithFilename(filename).WithReader(bytes.NewBufferString(input)).Parse() - - if len(errs) > 0 { - return nil, nil, errs - } - - return stmts, comment, nil + return ParseStatementsWithOpts(filename, input, ParserOptions{}) } -// ParseStatementsWithOpts returns a slice of parsed statements, and has an additional input ParserOptions -// This is the default return value from the parser. -func ParseStatementsWithOpts(filename, input string, popts ParserOptions) ([]Statement, []*Comment, []Annotations, error) { +// ParseStatementsWithOpts returns a slice of parsed statements. This is the +// default return value from the parser. +func ParseStatementsWithOpts(filename, input string, popts ParserOptions) ([]Statement, []*Comment, error) { parser := NewParser().WithFilename(filename).WithReader(bytes.NewBufferString(input)) if popts.ProcessAnnotation { parser.WithProcessAnnotation(popts.ProcessAnnotation) } - stmts, comment, annotations, errs := parser.Parse() + stmts, comments, errs := parser.Parse() if len(errs) > 0 { - return nil, nil, nil, errs + return nil, nil, errs } - return stmts, comment, annotations, nil + return stmts, comments, nil } -func parseModule(filename string, stmts []Statement, comments []*Comment, annotation []Annotations) (*Module, error) { +func parseModule(filename string, stmts []Statement, comments []*Comment) (*Module, error) { if len(stmts) == 0 { return nil, NewError(ParseErr, &Location{File: filename}, "empty module") @@ -616,14 +608,13 @@ func parseModule(filename string, stmts []Statement, comments []*Comment, annota } mod := &Module{ - Package: _package, - Annotation: annotation, + Package: _package, } // The comments slice only holds comments that were not their own statements. mod.Comments = append(mod.Comments, comments...) - for _, stmt := range stmts[1:] { + for i, stmt := range stmts[1:] { switch stmt := stmt.(type) { case *Import: mod.Imports = append(mod.Imports, stmt) @@ -636,20 +627,42 @@ func parseModule(filename string, stmts []Statement, comments []*Comment, annota errs = append(errs, NewError(ParseErr, stmt[0].Location, err.Error())) } else { mod.Rules = append(mod.Rules, rule) + + // NOTE(tsandall): the statement should now be interpreted as a + // rule so update the statement list. This is important for the + // logic below that associates annotations with statements. + stmts[i+1] = rule } case *Package: errs = append(errs, NewError(ParseErr, stmt.Loc(), "unexpected package")) - case *Comment: // Ignore comments, they're handled above. + case *Annotations: + mod.Annotations = append(mod.Annotations, stmt) + case *Comment: + // Ignore comments, they're handled above. default: panic("illegal value") // Indicates grammar is out-of-sync with code. } } - if len(errs) == 0 { - return mod, nil + if len(errs) > 0 { + return nil, errs + } + + // Find first non-annotation statement following each annotation and attach + // the annotation to that statement. + for _, a := range mod.Annotations { + for _, stmt := range stmts { + _, ok := stmt.(*Annotations) + if !ok { + if stmt.Loc().Row > a.Location.Row { + a.Node = stmt + break + } + } + } } - return nil, errs + return mod, nil } func setRuleModule(rule *Rule, module *Module) { diff --git a/ast/parser_test.go b/ast/parser_test.go index 4304653585..f1c62834c5 100644 --- a/ast/parser_test.go +++ b/ast/parser_test.go @@ -2662,13 +2662,21 @@ else = { `), curElse.Head.Value.Location) } -func TestGetAnnotation(t *testing.T) { +func TestAnnotations(t *testing.T) { + + dataServers := MustParseRef("data.servers") + dataNetworks := MustParseRef("data.networks") + dataPorts := MustParseRef("data.ports") + + schemaServers := MustParseRef("schema.servers") + schemaNetworks := MustParseRef("schema.networks") + schemaPorts := MustParseRef("schema.ports") tests := []struct { note string module string expNumComments int - expAnnotations []Annotations + expAnnotations []*Annotations expError string }{ { @@ -2680,7 +2688,6 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing a single schema # METADATA # scope: rule # schemas: @@ -2690,12 +2697,15 @@ public_servers[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`, - expNumComments: 5, - expAnnotations: append(make([]Annotations, 0), - &SchemaAnnotations{ - SchemaAnnotation: append(make([]SchemaAnnotation, 0), SchemaAnnotation{Path: "data.servers", Schema: "schemas.servers"}), - Scope: ruleScope, - }), + expNumComments: 4, + expAnnotations: []*Annotations{ + &Annotations{ + Schemas: []*SchemaAnnotation{ + {Path: dataServers, Schema: schemaServers}, + }, + Scope: annotationScopeRule, + }, + }, }, { note: "Multiple annotations on multiple lines", @@ -2706,7 +2716,6 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: @@ -2718,12 +2727,17 @@ public_servers[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`, - expNumComments: 7, - expAnnotations: append(make([]Annotations, 0), - &SchemaAnnotations{ - SchemaAnnotation: append(make([]SchemaAnnotation, 0), SchemaAnnotation{Path: "data.servers", Schema: "schemas.servers"}, SchemaAnnotation{Path: "data.networks", Schema: "schemas.networks"}, SchemaAnnotation{Path: "data.ports", Schema: "schemas.ports"}), - Scope: ruleScope, - }), + expNumComments: 6, + expAnnotations: []*Annotations{ + &Annotations{ + Schemas: []*SchemaAnnotation{ + {Path: dataServers, Schema: schemaServers}, + {Path: dataNetworks, Schema: schemaNetworks}, + {Path: dataPorts, Schema: schemaPorts}, + }, + Scope: annotationScopeRule, + }, + }, }, { note: "Comment in between metadata and rule (valid)", @@ -2734,25 +2748,30 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: # - data.servers: schema.servers # - data.networks: schema.networks # - data.ports: schema.ports -#This is a comment after the metadata yaml + +# This is a comment after the metadata YAML public_servers[server] { server = servers[i]; server.ports[j] = ports[k].id ports[k].networks[l] = networks[m].id; networks[m].public = true }`, - expNumComments: 8, - expAnnotations: append(make([]Annotations, 0), - &SchemaAnnotations{ - SchemaAnnotation: append(make([]SchemaAnnotation, 0), SchemaAnnotation{Path: "data.servers", Schema: "schemas.servers"}, SchemaAnnotation{Path: "data.networks", Schema: "schemas.networks"}, SchemaAnnotation{Path: "data.ports", Schema: "schemas.ports"}), - Scope: ruleScope, - }), + expNumComments: 7, + expAnnotations: []*Annotations{ + &Annotations{ + Schemas: []*SchemaAnnotation{ + {Path: dataServers, Schema: schemaServers}, + {Path: dataNetworks, Schema: schemaNetworks}, + {Path: dataPorts, Schema: schemaPorts}, + }, + Scope: annotationScopeRule, + }, + }, }, { note: "Empty comment line in between metadata and rule (valid)", @@ -2763,7 +2782,6 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: @@ -2776,12 +2794,17 @@ public_servers[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`, - expNumComments: 8, - expAnnotations: append(make([]Annotations, 0), - &SchemaAnnotations{ - SchemaAnnotation: append(make([]SchemaAnnotation, 0), SchemaAnnotation{Path: "data.servers", Schema: "schemas.servers"}, SchemaAnnotation{Path: "data.networks", Schema: "schemas.networks"}, SchemaAnnotation{Path: "data.ports", Schema: "schemas.ports"}), - Scope: ruleScope, - }), + expNumComments: 7, + expAnnotations: []*Annotations{ + &Annotations{ + Schemas: []*SchemaAnnotation{ + {Path: dataServers, Schema: schemaServers}, + {Path: dataNetworks, Schema: schemaNetworks}, + {Path: dataPorts, Schema: schemaPorts}, + }, + Scope: annotationScopeRule, + }, + }, }, { note: "Ill-structured (invalid) metadata start", @@ -2792,7 +2815,6 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: @@ -2805,11 +2827,10 @@ public_servers[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`, - expNumComments: 8, - expAnnotations: nil, + expError: "rego_parse_error: yaml: line 7: could not find expected ':'", }, { - note: "Ill-structured (valid) annotation", + note: "Ill-structured (invalid) annotation document path", module: ` package opa.examples @@ -2817,22 +2838,38 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: -# - data/servers: schemas/servers +# - data/servers: schema.servers public_servers[server] { server = servers[i]; server.ports[j] = ports[k].id ports[k].networks[l] = networks[m].id; networks[m].public = true }`, - expNumComments: 5, - expAnnotations: append(make([]Annotations, 0), - &SchemaAnnotations{ - SchemaAnnotation: append(make([]SchemaAnnotation, 0), SchemaAnnotation{Path: "data/servers", Schema: "schemas/servers"}), - Scope: ruleScope, - }), + expNumComments: 4, + expError: "rego_parse_error: invalid document reference", + }, + { + note: "Ill-structured (invalid) annotation schema path", + module: ` +package opa.examples + +import data.servers +import data.networks +import data.ports + +# METADATA +# scope: rule +# schemas: +# - data.servers: schema/servers +public_servers[server] { + server = servers[i]; server.ports[j] = ports[k].id + ports[k].networks[l] = networks[m].id; + networks[m].public = true +}`, + expNumComments: 4, + expError: "rego_parse_error: invalid schema reference", }, { note: "Ill-structured (invalid) annotation", @@ -2843,7 +2880,6 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: @@ -2854,8 +2890,7 @@ public_servers[server] { networks[m].public = true }`, expNumComments: 5, - expAnnotations: nil, - expError: "unable to unmarshall the metadata yaml in comment", + expError: "rego_parse_error: yaml: unmarshal errors:\n line 3: cannot unmarshal !!str", }, { note: "Indentation error in yaml", @@ -2866,7 +2901,6 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: @@ -2878,9 +2912,8 @@ public_servers[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`, - expNumComments: 7, - expAnnotations: nil, - expError: "unable to unmarshall the metadata yaml in comment", + expNumComments: 6, + expError: "rego_parse_error: yaml: line 3: did not find expected key", }, { note: "Multiple rules with and without metadata", @@ -2891,7 +2924,6 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: @@ -2910,12 +2942,17 @@ public_servers_1[server] { networks[m].public = true server.typo # won't catch this type error since rule has no schema metadata }`, - expNumComments: 8, - expAnnotations: append(make([]Annotations, 0), - &SchemaAnnotations{ - SchemaAnnotation: append(make([]SchemaAnnotation, 0), SchemaAnnotation{Path: "data.servers", Schema: "schemas.servers"}, SchemaAnnotation{Path: "data.networks", Schema: "schemas.networks"}, SchemaAnnotation{Path: "data.ports", Schema: "schemas.ports"}), - Scope: ruleScope, - }), + expNumComments: 7, + expAnnotations: []*Annotations{ + &Annotations{ + Schemas: []*SchemaAnnotation{ + {Path: dataServers, Schema: schemaServers}, + {Path: dataNetworks, Schema: schemaNetworks}, + {Path: dataPorts, Schema: schemaPorts}, + }, + Scope: annotationScopeRule, + }, + }, }, { note: "Multiple rules with metadata", @@ -2926,7 +2963,6 @@ import data.servers import data.networks import data.ports -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: @@ -2935,7 +2971,6 @@ public_servers[server] { server = servers[i] } -#Schema annotation for this rule referencing three schemas # METADATA # scope: rule # schemas: @@ -2945,17 +2980,48 @@ public_servers_1[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`, - expNumComments: 11, - expAnnotations: append(make([]Annotations, 0), - &SchemaAnnotations{ - SchemaAnnotation: append(make([]SchemaAnnotation, 0), SchemaAnnotation{Path: "data.servers", Schema: "schemas.servers"}), - Scope: ruleScope, - Rule: MustParseRule(`public_servers[server] { server = servers[i] }`), - }, &SchemaAnnotations{ - SchemaAnnotation: append(make([]SchemaAnnotation, 0), SchemaAnnotation{Path: "data.networks", Schema: "schemas.networks"}, SchemaAnnotation{Path: "data.ports", Schema: "schemas.ports"}), - Scope: ruleScope, - Rule: MustParseRule(`public_servers_1[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`), - }), + expNumComments: 9, + expAnnotations: []*Annotations{ + &Annotations{ + Schemas: []*SchemaAnnotation{ + {Path: dataServers, Schema: schemaServers}, + }, + Scope: annotationScopeRule, + Node: MustParseRule(`public_servers[server] { server = servers[i] }`), + }, + &Annotations{ + Schemas: []*SchemaAnnotation{ + + {Path: dataNetworks, Schema: schemaNetworks}, + {Path: dataPorts, Schema: schemaPorts}, + }, + Scope: annotationScopeRule, + Node: MustParseRule(`public_servers_1[server] { ports[k].networks[l] = networks[m].id; networks[m].public = true }`), + }, + }, + }, + { + note: "Empty annotation error due to whitespace following METADATA hint", + module: `package test + +# METADATA + +# scope: rule +p { input.x > 7 }`, + expError: "test.rego:3: rego_parse_error: expected METADATA block, found whitespace", + }, + { + note: "Annotation on constant", + module: ` +package test + +# METADATA +# scope: rule +p := 7`, + expNumComments: 2, + expAnnotations: []*Annotations{ + {Scope: annotationScopeRule}, + }, }, } @@ -2974,31 +3040,11 @@ public_servers_1[server] { } if len(mod.Comments) != tc.expNumComments { - t.Errorf("Expected %v comments but got %v", tc.expNumComments, len(mod.Comments)) - } - - annotations := mod.Annotation - if len(annotations) != len(tc.expAnnotations) { - t.Errorf("Expected %v annotations but got %v", len(tc.expAnnotations), len(annotations)) + t.Fatalf("Expected %v comments but got %v", tc.expNumComments, len(mod.Comments)) } - for _, annot := range annotations { - schemaAnnots, ok := annot.(*SchemaAnnotations) - if !ok { - t.Fatalf("Expected err: %v but no error from parse module", tc.expError) - } - for _, tcannot := range tc.expAnnotations { - tcschemaAnnots, ok := tcannot.(*SchemaAnnotations) - if !ok { - t.Fatalf("Expected err: %v but no error from parse module", tc.expError) - } - if schemaAnnots.Scope == ruleScope && tcschemaAnnots.Scope == ruleScope && tcschemaAnnots.Rule != nil && schemaAnnots.Rule.Head.Name == tcschemaAnnots.Rule.Head.Name { - if len(schemaAnnots.SchemaAnnotation) != len(tcschemaAnnots.SchemaAnnotation) { - t.Errorf("Expected %v annotations but got %v", len(schemaAnnots.SchemaAnnotation), len(tcschemaAnnots.SchemaAnnotation)) - } - } - } - + if annotationsCompare(tc.expAnnotations, mod.Annotations) != 0 { + t.Fatalf("expected %v but got %v", tc.expAnnotations, mod.Annotations) } }) } diff --git a/ast/policy.go b/ast/policy.go index 2b27ac499d..c640744245 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -131,11 +131,11 @@ type ( // within a namespace (defined by the package) and optional // dependencies on external documents (defined by imports). Module struct { - Package *Package `json:"package"` - Imports []*Import `json:"imports,omitempty"` - Rules []*Rule `json:"rules,omitempty"` - Comments []*Comment `json:"comments,omitempty"` - Annotation []Annotations `json:"annotation,omitempty"` + Package *Package `json:"package"` + Imports []*Import `json:"imports,omitempty"` + Annotations []*Annotations `json:"annotations,omitempty"` + Rules []*Rule `json:"rules,omitempty"` + Comments []*Comment `json:"comments,omitempty"` } // Comment contains the raw text from the comment in the definition. @@ -144,27 +144,18 @@ type ( Location *Location } - // Annotations contains information extracted from metadata in comments - Annotations interface { - annotationMaker() - - // NOTE(tsandall): these are temporary interfaces that are required to support copy operations. - // When we get rid of the rule pointers, these may not be needed. - copy(Node) Annotations - node() Node - } - - // SchemaAnnotations contains information about schemas - SchemaAnnotations struct { - SchemaAnnotation []SchemaAnnotation `json:"schemaannotation"` - Scope string `json:"scope"` - Rule *Rule `json:"-"` + // Annotations represents metadata attached to other AST nodes such as rules. + Annotations struct { + Node Node `json:"-"` + Location *Location `json:"-"` + Scope string `json:"scope"` + Schemas []*SchemaAnnotation `json:"schemas,omitempty"` } - // SchemaAnnotation contains information about a schema + // SchemaAnnotation contains a schema declaration for the document identified by the path. SchemaAnnotation struct { - Path string `json:"path"` - Schema string `json:"schema"` + Path Ref `json:"path"` + Schema Ref `json:"schema"` } // Package represents the namespace of the documents produced @@ -239,17 +230,108 @@ type ( } ) -func (s *SchemaAnnotations) copy(node Node) Annotations { +func (s *Annotations) String() string { + bs, _ := json.Marshal(s) + return string(bs) +} + +// Loc returns the location of this annotation. +func (s *Annotations) Loc() *Location { + return s.Location +} + +// SetLoc updates the location of this annotation. +func (s *Annotations) SetLoc(l *Location) { + s.Location = l +} + +// Compare returns an integer indicating if s is less than, equal to, or greater +// than other. +func (s *Annotations) Compare(other *Annotations) int { + + if cmp := scopeCompare(s.Scope, other.Scope); cmp != 0 { + return cmp + } + + max := len(s.Schemas) + if len(other.Schemas) < max { + max = len(other.Schemas) + } + + for i := 0; i < max; i++ { + if cmp := s.Schemas[i].Compare(other.Schemas[i]); cmp != 0 { + return cmp + } + } + + if len(s.Schemas) > len(other.Schemas) { + return 1 + } else if len(s.Schemas) < len(other.Schemas) { + return -1 + } + + return 0 +} + +// Copy returns a deep copy of s. +func (s *Annotations) Copy(node Node) *Annotations { cpy := *s - cpy.Rule = node.(*Rule) + cpy.Schemas = make([]*SchemaAnnotation, len(s.Schemas)) + for i := range cpy.Schemas { + cpy.Schemas[i] = s.Schemas[i].Copy() + } + cpy.Node = node return &cpy } -func (s *SchemaAnnotations) node() Node { - return s.Rule +// Copy returns a deep copy of s. +func (s *SchemaAnnotation) Copy() *SchemaAnnotation { + cpy := *s + return &cpy +} + +// Compare returns an integer indicating if s is less than, equal to, or greater +// than other. +func (s *SchemaAnnotation) Compare(other *SchemaAnnotation) int { + + if cmp := s.Path.Compare(other.Path); cmp != 0 { + return cmp + } + + if cmp := s.Schema.Compare(other.Schema); cmp != 0 { + return cmp + } + + return 0 +} + +func scopeCompare(s1, s2 string) int { + + o1 := scopeOrder(s1) + o2 := scopeOrder(s2) + + if o2 < o1 { + return 1 + } else if o2 > o1 { + return -1 + } + + if s1 < s2 { + return -1 + } else if s2 < s1 { + return 1 + } + + return 0 } -func (*SchemaAnnotations) annotationMaker() {} +func scopeOrder(s string) int { + switch s { + case annotationScopeRule: + return 1 + } + return 0 +} // Compare returns an integer indicating whether mod is less than, equal to, // or greater than other. @@ -268,6 +350,9 @@ func (mod *Module) Compare(other *Module) int { if cmp := importsCompare(mod.Imports, other.Imports); cmp != 0 { return cmp } + if cmp := annotationsCompare(mod.Annotations, other.Annotations); cmp != 0 { + return cmp + } return rulesCompare(mod.Rules, other.Rules) } @@ -276,32 +361,38 @@ func (mod *Module) Copy() *Module { cpy := *mod cpy.Rules = make([]*Rule, len(mod.Rules)) - // NOTE(tsandall): only construct the map if annotations are present. This is a temporary - // workaround to deal with the lack of a stable index mapping annotations to rules. - var rules map[Node]Node - if len(mod.Annotation) > 0 { - rules = make(map[Node]Node, len(mod.Rules)) + var nodes map[Node]Node + + if len(mod.Annotations) > 0 { + nodes = make(map[Node]Node) } for i := range mod.Rules { cpy.Rules[i] = mod.Rules[i].Copy() cpy.Rules[i].Module = &cpy - - if rules != nil { - rules[mod.Rules[i]] = cpy.Rules[i] + if nodes != nil { + nodes[mod.Rules[i]] = cpy.Rules[i] } } - cpy.Annotation = make([]Annotations, len(mod.Annotation)) - for i := range mod.Annotation { - cpy.Annotation[i] = mod.Annotation[i].copy(rules[mod.Annotation[i].node()]) - } - cpy.Imports = make([]*Import, len(mod.Imports)) for i := range mod.Imports { cpy.Imports[i] = mod.Imports[i].Copy() + if nodes != nil { + nodes[mod.Imports[i]] = cpy.Imports[i] + } } + cpy.Package = mod.Package.Copy() + if nodes != nil { + nodes[mod.Package] = cpy.Package + } + + cpy.Annotations = make([]*Annotations, len(mod.Annotations)) + for i := range mod.Annotations { + cpy.Annotations[i] = mod.Annotations[i].Copy(nodes[mod.Annotations[i].Node]) + } + return &cpy } @@ -312,16 +403,36 @@ func (mod *Module) Equal(other *Module) bool { func (mod *Module) String() string { buf := []string{} + + byNode := map[Node][]*Annotations{} + for _, a := range mod.Annotations { + byNode[a.Node] = append(byNode[a.Node], a) + } + + appendAnnotationStrings := func(buf []string, node Node) []string { + if as, ok := byNode[node]; ok { + for i := range as { + buf = append(buf, "# METADATA") + buf = append(buf, "# "+as[i].String()) + } + } + return buf + } + + buf = appendAnnotationStrings(buf, mod.Package) buf = append(buf, mod.Package.String()) + if len(mod.Imports) > 0 { buf = append(buf, "") for _, imp := range mod.Imports { + buf = appendAnnotationStrings(buf, imp) buf = append(buf, imp.String()) } } if len(mod.Rules) > 0 { buf = append(buf, "") for _, rule := range mod.Rules { + buf = appendAnnotationStrings(buf, rule) buf = append(buf, rule.String()) } } diff --git a/ast/policy_test.go b/ast/policy_test.go index 8ca9e377c9..bc273a09b5 100644 --- a/ast/policy_test.go +++ b/ast/policy_test.go @@ -16,7 +16,7 @@ import ( func TestModuleJSONRoundTrip(t *testing.T) { - mod := MustParseModule(`package a.b.c + mod, err := ParseModuleWithOpts("test.rego", `package a.b.c import data.x.y as z import data.u.i @@ -40,7 +40,15 @@ a = true { xs = {a: b | input.y[a] = "foo"; b = input.z["bar"]} } b = true { xs = {{"x": a[i].a} | a[i].n = "bob"; b[x]} } call_values { f(x) != g(x) } assigned := 1 -`) + +# METADATA +# scope: rule +metadata := 7 +`, ParserOptions{ProcessAnnotation: true}) + + if err != nil { + t.Fatal(err) + } bs, err := json.Marshal(mod) if err != nil { @@ -61,6 +69,10 @@ assigned := 1 if mod.Rules[3].Path().String() != "data.a.b.c.t" { t.Fatal("expected path data.a.b.c.t for 4th rule in module but got:", mod.Rules[3].Path()) } + + if len(roundtrip.Annotations) != 1 { + t.Fatal("expected exactly one annotation") + } } func TestBodyEmptyJSON(t *testing.T) { @@ -515,6 +527,48 @@ func TestSomeDeclString(t *testing.T) { } } +func TestAnnotationsString(t *testing.T) { + a := &Annotations{ + Scope: "foo", + Schemas: []*SchemaAnnotation{ + { + Path: MustParseRef("data.bar"), + Schema: MustParseRef("schema.baz"), + }, + }, + } + + // NOTE(tsandall): for now, annotations are represented as JSON objects + // which are a subset of YAML. We could improve this in the future. + exp := `{"scope":"foo","schemas":[{"path":[{"type":"var","value":"data"},{"type":"string","value":"bar"}],"schema":[{"type":"var","value":"schema"},{"type":"string","value":"baz"}]}]}` + + if exp != a.String() { + t.Fatalf("expected %q but got %q", exp, a.String()) + } +} + +func TestModuleStringAnnotations(t *testing.T) { + module, err := ParseModuleWithOpts("test.rego", `package test + +# METADATA +# scope: rule +p := 7`, ParserOptions{ProcessAnnotation: true}) + + if err != nil { + t.Fatal(err) + } + + exp := `package test + +# METADATA +# {"scope":"rule"} +p := 7 { true }` + + if module.String() != exp { + t.Fatalf("expected %q but got %q", exp, module.String()) + } +} + func TestCommentCopy(t *testing.T) { comment := &Comment{ Text: []byte("foo bar baz"), diff --git a/ast/transform.go b/ast/transform.go index 7e9af9cc26..c3509edc33 100644 --- a/ast/transform.go +++ b/ast/transform.go @@ -59,6 +59,15 @@ func Transform(t Transformer, x interface{}) (interface{}, error) { return nil, fmt.Errorf("illegal transform: %T != %T", y.Rules[i], rule) } } + for i := range y.Annotations { + a, err := Transform(t, y.Annotations[i]) + if err != nil { + return nil, err + } + if y.Annotations[i], ok = a.(*Annotations); !ok { + return nil, fmt.Errorf("illegal transform: %T != %T", y.Annotations[i], a) + } + } for i := range y.Comments { comment, err := Transform(t, y.Comments[i]) if err != nil { diff --git a/ast/transform_test.go b/ast/transform_test.go index f9c71d34c2..72a38a9c5b 100644 --- a/ast/transform_test.go +++ b/ast/transform_test.go @@ -4,7 +4,9 @@ package ast -import "testing" +import ( + "testing" +) func TestTransform(t *testing.T) { module := MustParseModule(`package ex.this @@ -60,3 +62,45 @@ foo(x) = y { split(x, "that", y) } } } + +func TestTransformAnnotations(t *testing.T) { + + module, err := ParseModuleWithOpts("test.rego", `package test + +# METADATA +# scope: rule +p := 7`, ParserOptions{ProcessAnnotation: true}) + if err != nil { + t.Fatal(err) + } + + result, err := Transform(&GenericTransformer{ + func(x interface{}) (interface{}, error) { + if s, ok := x.(*Annotations); ok { + cpy := *s + cpy.Scope = "deadbeef" + return &cpy, nil + } + return x, nil + }, + }, module) + + resultMod, ok := result.(*Module) + if !ok { + t.Fatalf("Expected module from transform but got: %v", result) + } + + exp, err := ParseModuleWithOpts("test.rego", `package test + +# METADATA +# scope: deadbeef +p := 7`, ParserOptions{ProcessAnnotation: true}) + if err != nil { + t.Fatal(err) + } + + if resultMod.Compare(exp) != 0 { + t.Fatalf("expected:\n\n%v\n\ngot:\n\n%v", exp, resultMod) + } + +} diff --git a/ast/visit.go b/ast/visit.go index 105fb58ab8..139c4de3f5 100644 --- a/ast/visit.go +++ b/ast/visit.go @@ -52,6 +52,9 @@ func walk(v Visitor, x interface{}) { for _, r := range x.Rules { Walk(w, r) } + for _, a := range x.Annotations { + Walk(w, a) + } for _, c := range x.Comments { Walk(w, c) } @@ -280,6 +283,9 @@ func (vis *GenericVisitor) Walk(x interface{}) { for _, r := range x.Rules { vis.Walk(r) } + for _, a := range x.Annotations { + vis.Walk(a) + } for _, c := range x.Comments { vis.Walk(c) } @@ -398,6 +404,9 @@ func (vis *BeforeAfterVisitor) Walk(x interface{}) { for _, r := range x.Rules { vis.Walk(r) } + for _, a := range x.Annotations { + vis.Walk(a) + } for _, c := range x.Comments { vis.Walk(c) } diff --git a/ast/visit_test.go b/ast/visit_test.go index 7c338e86fb..1b3694508b 100644 --- a/ast/visit_test.go +++ b/ast/visit_test.go @@ -298,6 +298,28 @@ fn([x, y]) = z { json.unmarshal(x, z); z > y } } } +func TestVisitorAnnotations(t *testing.T) { + + module, err := ParseModuleWithOpts("test.rego", `package test + +# METADATA +# scope: rule +p := 7`, ParserOptions{ProcessAnnotation: true}) + if err != nil { + t.Fatal(err) + } + + vis := &testVis{} + + NewGenericVisitor(vis.Visit).Walk(module) + + exp := 20 + + if len(vis.elems) != exp { + t.Fatalf("expected %d elements but got %v: %v", exp, len(vis.elems), vis.elems) + } +} + func TestWalkVars(t *testing.T) { x := MustParseBody(`x = 1; data.abc[2] = y; y[z] = [q | q = 1]`) found := NewVarSet() diff --git a/cmd/parse.go b/cmd/parse.go index 25e789feff..06b6e49681 100644 --- a/cmd/parse.go +++ b/cmd/parse.go @@ -49,7 +49,7 @@ func parse(args []string, stdout io.Writer, stderr io.Writer) int { return 0 } - result, err := loader.Rego(args[0]) + result, err := loader.RegoWithOpts(args[0], ast.ParserOptions{ProcessAnnotation: true}) switch parseParams.format.String() { case parseFormatJSON: diff --git a/loader/loader.go b/loader/loader.go index a4c07ced77..970a950edc 100644 --- a/loader/loader.go +++ b/loader/loader.go @@ -108,12 +108,12 @@ type descriptor struct { } type fileLoader struct { - metrics metrics.Metrics - bvc *bundle.VerificationConfig - skipVerify bool - descriptors []*descriptor - files map[string]bundle.FileInfo - processAnnotation bool + metrics metrics.Metrics + bvc *bundle.VerificationConfig + skipVerify bool + descriptors []*descriptor + files map[string]bundle.FileInfo + opts ast.ParserOptions } // WithMetrics provides the metrics instance to use while loading @@ -136,7 +136,7 @@ func (fl *fileLoader) WithSkipBundleVerification(skipVerify bool) FileLoader { // WithProcessAnnotation enables or disables processing of schema annotations on rules func (fl *fileLoader) WithProcessAnnotation(processAnnotation bool) FileLoader { - fl.processAnnotation = processAnnotation + fl.opts.ProcessAnnotation = processAnnotation return fl } @@ -156,7 +156,7 @@ func (fl fileLoader) Filtered(paths []string, filter Filter) (*Result, error) { return err } - result, err := loadKnownTypes(path, bs, fl.metrics, fl.processAnnotation) + result, err := loadKnownTypes(path, bs, fl.metrics, fl.opts) if err != nil { if !isUnrecognizedFile(err) { return err @@ -164,7 +164,7 @@ func (fl fileLoader) Filtered(paths []string, filter Filter) (*Result, error) { if depth > 0 { return nil } - result, err = loadFileForAnyType(path, bs, fl.metrics) + result, err = loadFileForAnyType(path, bs, fl.metrics, fl.opts) if err != nil { return err } @@ -187,7 +187,7 @@ func (fl fileLoader) AsBundle(path string) (*bundle.Bundle, error) { WithMetrics(fl.metrics). WithBundleVerificationConfig(fl.bvc). WithSkipBundleVerification(fl.skipVerify). - WithProcessAnnotations(fl.processAnnotation) + WithProcessAnnotations(fl.opts.ProcessAnnotation) // For bundle directories add the full path in front of module file names // to simplify debugging. @@ -387,8 +387,13 @@ func AllRegos(paths []string) (*Result, error) { }) } -// Rego returns a RegoFile object loaded from the given path. +// Rego is deprecated. Use RegoWithOpts instead. func Rego(path string) (*RegoFile, error) { + return RegoWithOpts(path, ast.ParserOptions{}) +} + +// RegoWithOpts returns a RegoFile object loaded from the given path. +func RegoWithOpts(path string, opts ast.ParserOptions) (*RegoFile, error) { path, err := fileurl.Clean(path) if err != nil { return nil, err @@ -397,7 +402,7 @@ func Rego(path string) (*RegoFile, error) { if err != nil { return nil, err } - return loadRego(path, bs, metrics.New()) + return loadRego(path, bs, metrics.New(), opts) } // CleanPath returns the normalized version of a path that can be used as an identifier. @@ -578,12 +583,12 @@ func allRec(path string, filter Filter, errors *Errors, loaded *Result, depth in } } -func loadKnownTypes(path string, bs []byte, m metrics.Metrics, processAnnotation bool) (interface{}, error) { +func loadKnownTypes(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (interface{}, error) { switch filepath.Ext(path) { case ".json": return loadJSON(path, bs, m) case ".rego": - return loadRego(path, bs, m, processAnnotation) + return loadRego(path, bs, m, opts) case ".yaml", ".yml": return loadYAML(path, bs, m) default: @@ -598,8 +603,8 @@ func loadKnownTypes(path string, bs []byte, m metrics.Metrics, processAnnotation return nil, unrecognizedFile(path) } -func loadFileForAnyType(path string, bs []byte, m metrics.Metrics) (interface{}, error) { - module, err := loadRego(path, bs, m) +func loadFileForAnyType(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (interface{}, error) { + module, err := loadRego(path, bs, m, opts) if err == nil { return module, nil } @@ -620,17 +625,11 @@ func loadBundleFile(path string, bs []byte, m metrics.Metrics) (bundle.Bundle, e return br.Read() } -func loadRego(path string, bs []byte, m metrics.Metrics, parserOptions ...bool) (*RegoFile, error) { +func loadRego(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (*RegoFile, error) { m.Timer(metrics.RegoModuleParse).Start() var module *ast.Module var err error - if len(parserOptions) == 1 { - module, err = ast.ParseModuleWithOpts(path, string(bs), ast.ParserOptions{ - ProcessAnnotation: parserOptions[0], - }) - } else { - module, err = ast.ParseModule(path, string(bs)) - } + module, err = ast.ParseModuleWithOpts(path, string(bs), opts) m.Timer(metrics.RegoModuleParse).Stop() if err != nil { return nil, err