Skip to content

Commit

Permalink
refactor: added multi-thread rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
flare committed Sep 26, 2024
1 parent 147b9a6 commit 9c29324
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 57 deletions.
151 changes: 95 additions & 56 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
"os"
"os/signal"
"path/filepath"
"sync"
"syscall"
"time"

progressbar "github.com/schollz/progressbar/v3"
"github.com/sirupsen/logrus"
Expand All @@ -24,6 +24,10 @@ const BLOCKSIZE = 128_000
// Configurations
var shuffleBytes bool
var continuous bool
var threads int
var guard chan struct{}
var wg sync.WaitGroup
var finished bool

// Runtime globals
var log *logrus.Logger
Expand Down Expand Up @@ -228,68 +232,60 @@ func ShuffleRewriteFile(path string, info os.FileInfo) (err error) {
return nil
}

func Rewrite(path string, info os.FileInfo, err error) error {
// Get inode
stat, _ := info.Sys().(*syscall.Stat_t)
inode := stat.Ino

// Return early if already completed and not continuously rewriting
if !continuous {
for _, b := range completed.CompletedFiles {
if b == path {
log.Infof("Skipping file '%s'\n", path)

// Check if inode exists
inodeExists := false
for _, i := range completed.CompletedInodes {
if i == inode {
inodeExists = true
break
}
}

// If not exists, add
if !inodeExists {
completed.CompletedInodes = append(completed.CompletedInodes, inode)
func IsCompleted(path string, inode uint64) bool {
for _, b := range completed.CompletedFiles {
if b == path {
log.Infof("Skipping file '%s'\n", path)

// Check if inode exists
inodeExists := false
for _, i := range completed.CompletedInodes {
if i == inode {
inodeExists = true
break
}
}

// Return early
return nil
// If not in CompletedInodes, add to it
if !inodeExists {
completed.CompletedInodes = append(completed.CompletedInodes, inode)
}

// Return early
return false
}
}

for _, b := range completed.CompletedInodes {
if b == inode {
log.Infof("Skipping inode '%d'\n", inode)

// Check if path exists
pathExists := false
for _, i := range completed.CompletedFiles {
if i == path {
pathExists = true
break
}
}
for _, b := range completed.CompletedInodes {
if b == inode {
log.Infof("Skipping inode '%d'\n", inode)

// If not exists, add
if !pathExists {
completed.CompletedFiles = append(completed.CompletedFiles, path)
// Check if path exists
pathExists := false
for _, i := range completed.CompletedFiles {
if i == path {
pathExists = true
break
}
}

// Return early
return nil
// If not in CompletedFiles, add to it
if !pathExists {
completed.CompletedFiles = append(completed.CompletedFiles, path)
}

// Return early
return false
}
}

// Return early if error
if err != nil {
return err
}
return true
}

// Get file info if empty
func Rewrite(path string, info os.FileInfo, err error) error {
// Call lstat() if info is nil, return if error
if info == nil {
info, err = os.Stat(path)
info, err = os.Lstat(path)
if err != nil {
return err
}
Expand All @@ -300,6 +296,20 @@ func Rewrite(path string, info os.FileInfo, err error) error {
return nil
}

// Get file inode
stat, _ := info.Sys().(*syscall.Stat_t)
inode := stat.Ino

// Return early if error
if err != nil {
return err
}

// Return early if already completed and not continuously rewriting
if !continuous && !IsCompleted(path, inode) {
return nil
}

// Rewrite file
if shuffleBytes {
if err := ShuffleRewriteFile(path, info); err != nil {
Expand All @@ -321,13 +331,29 @@ func Rewrite(path string, info os.FileInfo, err error) error {
}
saveCompleted()

// Check if signal was raised
select {
case <-done:
// Return nil
return nil
}

func RewriteRouting(path string, info os.FileInfo, err error) error {
// Start goroutine
guard <- struct{}{}
wg.Add(1)
go func() {
err = Rewrite(path, info, err)
if err != nil {
log.Errorf("Rewrite failed: %+v", err)
}
wg.Done()
<-guard
}()

// Return error if finished
if finished {
return io.EOF
case <-time.After(1):
}

// Else continue
return nil
}

Expand All @@ -344,6 +370,7 @@ func main() {
// Get arguments
flag.BoolVar(&continuous, "c", false, "continuously rewrite")
flag.BoolVar(&shuffleBytes, "s", false, "shuffle bytes on rewrite")
flag.IntVar(&threads, "t", 1, "threads")
progname := filepath.Base(os.Args[0])
flag.Usage = func() {
fmt.Fprintf(os.Stderr, `
Expand All @@ -363,18 +390,30 @@ Flags:
os.Exit(1)
}

// Ensure quit
go func() {
<-done
log.Infof("Finishing...")
finished = true
}()

// Get all files and folders
guard = make(chan struct{}, threads)
for {
err := filepath.Walk(flag.Arg(0), Rewrite)
if err == io.EOF {
log.Infof("program exited successfully")
err := filepath.Walk(flag.Arg(0), RewriteRouting)
if err == io.EOF || finished {
log.Infof("Program exited successfully")
break
} else if err != nil {
close(guard)
panic(err)
}

if !continuous {
break
}
}

// Cleanup
wg.Wait()
}
4 changes: 3 additions & 1 deletion main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ func TestRewrite(t *testing.T) {
for i, size := range sizes {
// Prepare path
path := fmt.Sprintf("%s/%d", dir, i)
fmt.Printf("%s", path)

// Generate random sequence of bytes 16 megabytes
randomString := RandStringBytesMaskImprSrcUnsafe(size)
Expand Down Expand Up @@ -121,7 +122,8 @@ func TestRewrite(t *testing.T) {
assert.Equal(t, randomBytes, writtenBytes, "[step 1] written bytes != random bytes")

// Rewrite file
Rewrite(path, nil, err)
err = Rewrite(path, nil, err)
assert.NoError(t, err)

// Ensure equal
assert.Equal(t, randomBytes, writtenBytes, "[step 2] rewritten bytes != written bytes")
Expand Down

0 comments on commit 9c29324

Please sign in to comment.