Skip to content

Commit

Permalink
Merge pull request #459 from aduong/symlinks
Browse files Browse the repository at this point in the history
Overwrite existing dest when copying symlink and preserve link target
  • Loading branch information
priyawadhwa authored Nov 30, 2018
2 parents 7cde036 + f23cc32 commit 0b7fa58
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
8 changes: 6 additions & 2 deletions pkg/util/fs_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,8 +509,12 @@ func CopySymlink(src, dest string) error {
if err != nil {
return err
}
linkDst := filepath.Join(dest, link)
return os.Symlink(linkDst, dest)
if FilepathExists(dest) {
if err := os.RemoveAll(dest); err != nil {
return err
}
}
return os.Symlink(link, dest)
}

// CopyFile copies the file at src to dest
Expand Down
60 changes: 60 additions & 0 deletions pkg/util/fs_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -536,3 +536,63 @@ func TestExtractFile(t *testing.T) {
})
}
}

func TestCopySymlink(t *testing.T) {
type tc struct {
name string
linkTarget string
dest string
beforeLink func(r string) error
}

tcs := []tc{{
name: "absolute symlink",
linkTarget: "/abs/dest",
}, {
name: "relative symlink",
linkTarget: "rel",
}, {
name: "symlink copy overwrites existing file",
linkTarget: "/abs/dest",
dest: "overwrite_me",
beforeLink: func(r string) error {
return ioutil.WriteFile(filepath.Join(r, "overwrite_me"), nil, 0644)
},
}}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
tc := tc
t.Parallel()
r, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(r)

if tc.beforeLink != nil {
if err := tc.beforeLink(r); err != nil {
t.Fatal(err)
}
}
link := filepath.Join(r, "link")
dest := filepath.Join(r, "copy")
if tc.dest != "" {
dest = filepath.Join(r, tc.dest)
}
if err := os.Symlink(tc.linkTarget, link); err != nil {
t.Fatal(err)
}
if err := CopySymlink(link, dest); err != nil {
t.Fatal(err)
}
got, err := os.Readlink(dest)
if err != nil {
t.Fatalf("error reading link %s: %s", link, err)
}
if got != tc.linkTarget {
t.Errorf("link target does not match: %s != %s", got, tc.linkTarget)
}
})
}
}

0 comments on commit 0b7fa58

Please sign in to comment.