Skip to content

Commit

Permalink
refactor: add a package ast (#2448)
Browse files Browse the repository at this point in the history
* refactor: add ast package

* fix: fix lint errors
  • Loading branch information
suzuki-shunsuke authored Nov 11, 2023
1 parent 6379ebb commit b97bc21
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 93 deletions.
38 changes: 38 additions & 0 deletions pkg/ast/map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package ast

import (
"errors"

"github.com/goccy/go-yaml/ast"
)

func FindMappingValueFromNode(body ast.Node, key string) (*ast.MappingValueNode, error) {
values, err := NormalizeMappingValueNodes(body)
if err != nil {
return nil, err
}
return findMappingValue(values, key), nil
}

func findMappingValue(values []*ast.MappingValueNode, key string) *ast.MappingValueNode {
for _, value := range values {
sn, ok := value.Key.(*ast.StringNode)
if !ok {
continue
}
if sn.Value == key {
return value
}
}
return nil
}

func NormalizeMappingValueNodes(node ast.Node) ([]*ast.MappingValueNode, error) {
switch t := node.(type) {
case *ast.MappingNode:
return t.Values, nil
case *ast.MappingValueNode:
return []*ast.MappingValueNode{t}, nil
}
return nil, errors.New("node must be a mapping node or mapping value node")
}
6 changes: 1 addition & 5 deletions pkg/controller/generate/output/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,4 @@ package output

import "errors"

var (
errDocumentMustBeOne = errors.New("the number of document in aqua.yaml must be one")
errBodyFormat = errors.New("fails to parse a configuration file. Format is wrong. body must be *ast.MappingNode or *ast.MappingValueNode")
errPkgsNotFound = errors.New("the field 'packages' isn't found")
)
var errDocumentMustBeOne = errors.New("the number of document in aqua.yaml must be one")
38 changes: 5 additions & 33 deletions pkg/controller/generate/output/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"

wast "github.com/aquaproj/aqua/v2/pkg/ast"
"github.com/aquaproj/aqua/v2/pkg/config/aqua"
"github.com/goccy/go-yaml"
"github.com/goccy/go-yaml/ast"
Expand Down Expand Up @@ -43,30 +44,6 @@ func (o *Outputter) generateInsert(cfgFilePath string, pkgs []*aqua.Package) err
return nil
}

func getPkgsAST(values []*ast.MappingValueNode) *ast.MappingValueNode {
for _, mapValue := range values {
key, ok := mapValue.Key.(*ast.StringNode)
if !ok {
continue
}
if key.Value != "packages" {
continue
}
return mapValue
}
return nil
}

func getMappingValueNodeFromBody(body ast.Node) []*ast.MappingValueNode {
switch b := body.(type) {
case *ast.MappingNode:
return b.Values
case *ast.MappingValueNode:
return []*ast.MappingValueNode{b}
}
return nil
}

func appendPkgsNode(mapValue *ast.MappingValueNode, node ast.Node) error {
switch mapValue.Value.Type() {
case ast.NullType:
Expand All @@ -88,15 +65,10 @@ func updateASTFile(body ast.Node, pkgs []*aqua.Package) error {
return fmt.Errorf("convert packages to node: %w", err)
}

values := getMappingValueNodeFromBody(body)
if values == nil {
return errBodyFormat
}

mapValue := getPkgsAST(values)
if mapValue == nil {
return errPkgsNotFound
values, err := wast.FindMappingValueFromNode(body, "packages")
if err != nil {
return fmt.Errorf(`find a mapping value node "packages": %w`, err)
}

return appendPkgsNode(mapValue, node)
return appendPkgsNode(values, node)
}
9 changes: 5 additions & 4 deletions pkg/controller/update/ast/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ import (
"fmt"
"strings"

wast "github.com/aquaproj/aqua/v2/pkg/ast"
"github.com/goccy/go-yaml/ast"
"github.com/sirupsen/logrus"
)

func UpdatePackages(logE *logrus.Entry, file *ast.File, newVersions map[string]string) (bool, error) {
body := file.Docs[0].Body // DocumentNode
mv, err := findMappingValueFromNode(body, "packages")
mv, err := wast.FindMappingValueFromNode(body, "packages")
if err != nil {
return false, err
return false, fmt.Errorf(`find a mapping value node "packages": %w`, err)
}

seq, ok := mv.Value.(*ast.SequenceNode)
Expand All @@ -34,9 +35,9 @@ func UpdatePackages(logE *logrus.Entry, file *ast.File, newVersions map[string]s
}

func parsePackageNode(logE *logrus.Entry, node ast.Node, newVersions map[string]string) (bool, error) { //nolint:cyclop,funlen
mvs, err := normalizeMappingValueNodes(node)
mvs, err := wast.NormalizeMappingValueNodes(node)
if err != nil {
return false, err
return false, fmt.Errorf("normalize mapping value node: %w", err)
}
var registryName string
var pkgName string
Expand Down
74 changes: 23 additions & 51 deletions pkg/controller/update/ast/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,38 @@ package ast

import (
"errors"
"fmt"

wast "github.com/aquaproj/aqua/v2/pkg/ast"
"github.com/goccy/go-yaml/ast"
"github.com/sirupsen/logrus"
)

const typeStandard = "standard"

func findMappingValue(values []*ast.MappingValueNode, key string) *ast.MappingValueNode {
for _, value := range values {
sn, ok := value.Key.(*ast.StringNode)
if !ok {
continue
}
if sn.Value == key {
return value
}
}
return nil
}
func UpdateRegistries(logE *logrus.Entry, file *ast.File, newVersions map[string]string) (bool, error) {
body := file.Docs[0].Body // DocumentNode

func normalizeMappingValueNodes(node ast.Node) ([]*ast.MappingValueNode, error) {
switch t := node.(type) {
case *ast.MappingNode:
return t.Values, nil
case *ast.MappingValueNode:
return []*ast.MappingValueNode{t}, nil
mv, err := wast.FindMappingValueFromNode(body, "registries")
if err != nil {
return false, fmt.Errorf(`find a mapping value node "registries": %w`, err)
}
return nil, errors.New("node must be a mapping node or mapping value node")
}

func findMappingValueFromNode(body ast.Node, key string) (*ast.MappingValueNode, error) {
values, err := normalizeMappingValueNodes(body)
if err != nil {
return nil, err
seq, ok := mv.Value.(*ast.SequenceNode)
if !ok {
return false, errors.New("the value must be a sequence node")
}
updated := false
for _, value := range seq.Values {
up, err := parseRegistryNode(logE, value, newVersions)
if err != nil {
return false, err
}
if up {
updated = true
}
}
return findMappingValue(values, key), nil
return updated, nil
}

func updateRegistryVersion(logE *logrus.Entry, refNode *ast.StringNode, rgstName, newVersion string) bool {
Expand All @@ -61,9 +57,9 @@ func updateRegistryVersion(logE *logrus.Entry, refNode *ast.StringNode, rgstName
}

func parseRegistryNode(logE *logrus.Entry, node ast.Node, newVersions map[string]string) (bool, error) { //nolint:gocognit,cyclop,funlen
mvs, err := normalizeMappingValueNodes(node)
mvs, err := wast.NormalizeMappingValueNodes(node)
if err != nil {
return false, err
return false, fmt.Errorf("normalize a mapping value node: %w", err)
}
var refNode *ast.StringNode
var newVersion string
Expand Down Expand Up @@ -127,27 +123,3 @@ func parseRegistryNode(logE *logrus.Entry, node ast.Node, newVersions map[string
}
return updateRegistryVersion(logE, refNode, rgstName, version), nil
}

func UpdateRegistries(logE *logrus.Entry, file *ast.File, newVersions map[string]string) (bool, error) {
body := file.Docs[0].Body // DocumentNode
mv, err := findMappingValueFromNode(body, "registries")
if err != nil {
return false, err
}

seq, ok := mv.Value.(*ast.SequenceNode)
if !ok {
return false, errors.New("the value must be a sequence node")
}
updated := false
for _, value := range seq.Values {
up, err := parseRegistryNode(logE, value, newVersions)
if err != nil {
return false, err
}
if up {
updated = true
}
}
return updated, nil
}

0 comments on commit b97bc21

Please sign in to comment.