Skip to content

Commit

Permalink
error handling & test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
soerenkoehler committed Apr 1, 2024
1 parent 176d463 commit 1ef2804
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 126 deletions.
46 changes: 24 additions & 22 deletions chdiff/chdiff.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,32 +88,34 @@ func (cmd *CmdCreate) Run(deps ChdiffDependencies) error {
}

func (cmd *CmdVerify) Run(deps ChdiffDependencies) error {
oldDigest, err := deps.DigestRead(
cmd.RootPath,
defaultDigestFile(cmd.cmdDigest))

if err != nil {
util.Error(err.Error())
return err
}
var chain util.ChainContext
var oldDigest digest.Digest

deps.DiffPrint(
deps.Stdout(),
deps.DigestCompare(
oldDigest,
deps.DigestCalculate(
cmd.RootPath,
oldDigest.Algorithm)))

return nil
chain.Chain(func() {
oldDigest, chain.Err = deps.DigestRead(
cmd.RootPath,
defaultDigestFile(cmd.cmdDigest))
}).Chain(func() {
deps.DiffPrint(
deps.Stdout(),
deps.DigestCompare(
oldDigest,
deps.DigestCalculate(
cmd.RootPath,
oldDigest.Algorithm)))
}).ChainError("verify")

return chain.Err
}

func loadConfig() {
if err := json.Unmarshal(readConfigFile(), &common.Config); err != nil {
util.Fatal("reading config: %s", err.Error())
}
util.SetLogLevelByName(common.Config.LogLevel)
util.Debug("%+v", common.Config)
chain := &util.ChainContext{}
chain.Chain(func() {
chain.Err = json.Unmarshal(readConfigFile(), &common.Config)
}).Chain(func() {
util.SetLogLevelByName(common.Config.LogLevel)
util.Debug("%+v", common.Config)
}).ChainFatal("reading config")
}

func readConfigFile() []byte {
Expand Down
4 changes: 2 additions & 2 deletions chdiff/chdiff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,12 @@ func (s *TSChdiff) TestDigestVerifyMissingDigestFile() {
On("exit", mock.Anything).Return().
On("DigestRead", absDataPath, absDigestFile).Return(
digest.Digest{},
fmt.Errorf("read error"))
fmt.Errorf("no such file"))

chdiff.Chdiff("TEST", []string{"", "v", "x"}, s.Dependencies)

s.Dependencies.AssertExpectations(s.T())
assert.Contains(s.T(), s.Stderr.String(), "[E] read error")
assert.Contains(s.T(), s.Stderr.String(), "[E] verify: no such file")
}

func (s *TSChdiff) TestDigestCreateSHA256DefaultName() {
Expand Down
129 changes: 59 additions & 70 deletions digest/calculator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ import (
"encoding/hex"
"hash"
"io"
"io/fs"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -47,13 +46,15 @@ func Calculate(
go func() {
defer close(context.digest)

absPath, err := filepath.Abs(context.rootPath)
if err != nil {
util.Fatal(err.Error())
}
var absPath string

context.processPath(absPath)
context.waitGroup.Wait()
chain := &util.ChainContext{}
chain.Chain(func() {
absPath, chain.Err = filepath.Abs(context.rootPath)
}).Chain(func() {
context.processPath(absPath)
context.waitGroup.Wait()
}).ChainFatal("calculate")
}()

result := NewDigest(rootPath, time.Now())
Expand Down Expand Up @@ -85,51 +86,54 @@ func (context digestContext) processPath(path string) {
}

func (context digestContext) processDir(dir string) {
entries, err := os.ReadDir(dir)
if err != nil {
util.Error(err.Error())
return
}

for _, entry := range entries {
context.processPath(filepath.Join(dir, entry.Name()))
}
var entries []fs.DirEntry

chain := &util.ChainContext{}
chain.Chain(func() {
entries, chain.Err = os.ReadDir(dir)
}).Chain(func() {
for _, entry := range entries {
context.processPath(filepath.Join(dir, entry.Name()))
}
}).ChainError("process dir")
}

func (context digestContext) processFile(file string) {
chain(func() (string, error){
return filepath.Rel(context.rootPath, file)
}, func(file))
if err != nil {
util.Error(err.Error())
return
}

input, err := os.Open(file)
if err != nil {
util.Error(err.Error())
return
}

defer input.Close()

hash := getNewHash(context.algorithm)
io.Copy(hash, input)

context.digest <- digestEntry{
file: relativePath,
hash: hex.EncodeToString(hash.Sum(nil)),
}
var relativePath string
var input *os.File

chain := &util.ChainContext{}
chain.Chain(func() {
relativePath, chain.Err = filepath.Rel(context.rootPath, file)
}).Chain(func() {
input, chain.Err = os.Open(file)
}).Chain(func() {
defer input.Close()

hash := getNewHash(context.algorithm)
io.Copy(hash, input)

context.digest <- digestEntry{
file: relativePath,
hash: hex.EncodeToString(hash.Sum(nil)),
}
}).ChainError("process file")
}

func (context digestContext) pathExcluded(path string) bool {
return chain(func() (string, error) {
return filepath.Rel(context.rootPath, path)
}, func(relPath string) bool {
return matchAnyPattern(path, common.Config.Exclude.Absolute) ||
matchAnyPattern(relPath, common.Config.Exclude.Relative) ||
matchAnyPattern(filepath.Base(relPath), common.Config.Exclude.Anywhere)
}, false)
var relativePath string
var result bool

chain := &util.ChainContext{}
chain.Chain(func() {
relativePath, chain.Err = filepath.Rel(context.rootPath, path)
}).Chain(func() {
result = matchAnyPattern(path, common.Config.Exclude.Absolute) ||
matchAnyPattern(relativePath, common.Config.Exclude.Relative) ||
matchAnyPattern(filepath.Base(relativePath), common.Config.Exclude.Anywhere)
}).ChainError("filter path")

return result
}

func matchAnyPattern(path string, patterns []string) bool {
Expand All @@ -142,9 +146,14 @@ func matchAnyPattern(path string, patterns []string) bool {
}

func matchPattern(path, pattern string) bool {
return chain(func() (bool, error) {
return filepath.Match(pattern, path)
}, identity[bool], false)
var chain util.ChainContext
var result bool

chain.Chain(func() {
result, chain.Err = filepath.Match(pattern, path)
}).ChainError("match path")

return result
}

func getNewHash(algorithm HashType) hash.Hash {
Expand All @@ -157,23 +166,3 @@ func getNewHash(algorithm HashType) hash.Hash {
return sha256.New()
}
}

func chain[T any](err error, errVal T, f ...func()) {
while() {

}
}

func chainx[T, U any](f func() (T, error), g func(in T) U, errVal U) U {
first_result, err := f()
if err == nil {
return g(first_result)
} else {
util.Error(err.Error())
return errVal
}
}

func identity[T any](in T) T {
return in
}
78 changes: 46 additions & 32 deletions digest/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package digest
import (
"bufio"
"fmt"
"io/fs"
"os"
"strings"

"github.com/soerenkoehler/go-chdiff/util"
)

const SEPARATOR_TEXT = " "
Expand All @@ -15,48 +18,59 @@ type Reader func(digestRootPath, digestFile string) (Digest, error)
type Writer func(digest Digest, digestFile string) error

func Load(digestPath, digestFile string) (Digest, error) {
var fileInfo fs.FileInfo
var input *os.File
var digest Digest

digestFileInfo, err := os.Lstat(digestFile)
if err != nil {
// hack for better error message under Windows
err.(*os.PathError).Op = "lstat"
return Digest{}, err
}

digest := NewDigest(digestPath, digestFileInfo.ModTime().Local())
chain := &util.ChainContext{}

input, err := os.Open(digestFile)
if err != nil {
return Digest{}, err
}
chain.Chain(func() {
fileInfo, chain.Err = os.Lstat(digestFile)
hackLstatErrorForWindow(&chain.Err)
}).Chain(func() {
input, chain.Err = os.Open(digestFile)
}).Chain(func() {
defer input.Close()

defer input.Close()
digest = NewDigest(digestPath, fileInfo.ModTime().Local())

lines := bufio.NewScanner(input)
for lines.Scan() {
normalized := strings.Replace(lines.Text(), SEPARATOR_TEXT, SEPARATOR_BINARY, 1)
tokens := strings.SplitN(normalized, SEPARATOR_BINARY, 2)
if len(tokens) != 2 {
return Digest{}, fmt.Errorf("invalid digest file")
lines := bufio.NewScanner(input)
for lines.Scan() {
normalized := strings.Replace(lines.Text(), SEPARATOR_TEXT, SEPARATOR_BINARY, 1)
tokens := strings.SplitN(normalized, SEPARATOR_BINARY, 2)
if len(tokens) != 2 {
chain.Err = fmt.Errorf("invalid digest file")
return
}
digest.AddFileHash(tokens[1], tokens[0])
}
digest.AddFileHash(tokens[1], tokens[0])
}
})

return digest, nil
return digest, chain.Err
}

func Save(digest Digest, digestFile string) error {
output, err := os.Create(digestFile)
if err != nil {
return err
// hack for better error message under Windows
func hackLstatErrorForWindow(err *error) {
if *err != nil {
(*err).(*os.PathError).Op = "lstat"
}
}

defer output.Close()
func Save(digest Digest, digestFile string) error {
var output *os.File

for k, v := range *digest.Entries {
fmt.Fprintf(output, "%v%v%v\n", v, SEPARATOR_BINARY, k)
}
os.Chtimes(digestFile, digest.Location.Time, digest.Location.Time)
chain := &util.ChainContext{}

chain.Chain(func() {
output, chain.Err = os.Create(digestFile)
}).Chain(func() {
defer output.Close()

for k, v := range *digest.Entries {
fmt.Fprintf(output, "%v%v%v\n", v, SEPARATOR_BINARY, k)
}
os.Chtimes(digestFile, digest.Location.Time, digest.Location.Time)
})

return nil
return chain.Err
}
24 changes: 24 additions & 0 deletions util/chain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package util

type ChainContext struct {
Err error
}

func (ctx *ChainContext) Chain(f func()) *ChainContext {
if ctx.Err == nil {
f()
}
return ctx
}

func (ctx *ChainContext) ChainError(label string) {
if ctx.Err != nil {
Error("%s: %s", label, ctx.Err.Error())
}
}

func (ctx *ChainContext) ChainFatal(label string) {
if ctx.Err != nil {
Fatal("%s: %s", label, ctx.Err.Error())
}
}

0 comments on commit 1ef2804

Please sign in to comment.