diff --git a/filewriter.go b/filewriter.go index e1ee736..bf4bcf6 100644 --- a/filewriter.go +++ b/filewriter.go @@ -22,7 +22,7 @@ func WriteTo(nd Node, fpath string) error { case *Symlink: return os.Symlink(nd.Target, fpath) case File: - f, err := os.OpenFile(fpath, os.O_EXCL|os.O_CREATE|os.O_WRONLY, 0666) + f, err := createNewFile(fpath) defer f.Close() if err != nil { return err diff --git a/filewriter_unix.go b/filewriter_unix.go index 4c2f09b..1589594 100644 --- a/filewriter_unix.go +++ b/filewriter_unix.go @@ -3,10 +3,18 @@ package files -import "strings" +import ( + "os" + "strings" + "syscall" +) var invalidChars = `/` + "\x00" func isValidFilename(filename string) bool { return !strings.ContainsAny(filename, invalidChars) } + +func createNewFile(path string) (*os.File, error) { + return os.OpenFile(path, os.O_EXCL|os.O_CREATE|os.O_WRONLY|syscall.O_NOFOLLOW, 0666) +} diff --git a/filewriter_windows.go b/filewriter_windows.go index 21aae01..4392e0e 100644 --- a/filewriter_windows.go +++ b/filewriter_windows.go @@ -3,7 +3,10 @@ package files -import "strings" +import ( + "os" + "strings" +) var invalidChars = `<>:"/\|?*` + "\x00" @@ -37,3 +40,7 @@ func isValidFilename(filename string) bool { return !strings.ContainsAny(filename, invalidChars) && !isReservedName } + +func createNewFile(path string) (*os.File, error) { + return os.OpenFile(path, os.O_EXCL|os.O_CREATE|os.O_WRONLY, 0666) +}