diff --git a/examples/go-sftp-server/main.go b/examples/go-sftp-server/main.go index ba902b6f..aef436cb 100644 --- a/examples/go-sftp-server/main.go +++ b/examples/go-sftp-server/main.go @@ -20,10 +20,13 @@ func main() { var ( readOnly bool debugStderr bool + winRoot bool ) flag.BoolVar(&readOnly, "R", false, "read-only server") flag.BoolVar(&debugStderr, "e", false, "debug to stderr") + flag.BoolVar(&winRoot, "wr", false, "windows root") + flag.Parse() debugStream := io.Discard @@ -128,6 +131,11 @@ func main() { fmt.Fprintf(debugStream, "Read write server\n") } + if winRoot { + serverOptions = append(serverOptions, sftp.WindowsRootEnumeratesDrives()) + fmt.Fprintf(debugStream, "Windows root enabled\n") + } + server, err := sftp.NewServer( channel, serverOptions..., diff --git a/go.mod b/go.mod index 0a2de367..ac5ee715 100644 --- a/go.mod +++ b/go.mod @@ -6,4 +6,5 @@ require ( github.com/kr/fs v0.1.0 github.com/stretchr/testify v1.8.0 golang.org/x/crypto v0.31.0 + golang.org/x/sys v0.28.0 // indirect ) diff --git a/server.go b/server.go index fb474c4f..cd656d8f 100644 --- a/server.go +++ b/server.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "io/fs" "io/ioutil" "os" "path/filepath" @@ -21,6 +22,18 @@ const ( SftpServerWorkerCount = 8 ) +type file interface { + Stat() (os.FileInfo, error) + ReadAt(b []byte, off int64) (int, error) + WriteAt(b []byte, off int64) (int, error) + Readdir(int) ([]os.FileInfo, error) + Name() string + Truncate(int64) error + Chmod(mode fs.FileMode) error + Chown(uid, gid int) error + Close() error +} + // Server is an SSH File Transfer Protocol (sftp) server. // This is intended to provide the sftp subsystem to an ssh server daemon. // This implementation currently supports most of sftp server protocol version 3, @@ -30,14 +43,15 @@ type Server struct { debugStream io.Writer readOnly bool pktMgr *packetManager - openFiles map[string]*os.File + openFiles map[string]file openFilesLock sync.RWMutex handleCount int workDir string + winRoot bool maxTxPacket uint32 } -func (svr *Server) nextHandle(f *os.File) string { +func (svr *Server) nextHandle(f file) string { svr.openFilesLock.Lock() defer svr.openFilesLock.Unlock() svr.handleCount++ @@ -57,7 +71,7 @@ func (svr *Server) closeHandle(handle string) error { return EBADF } -func (svr *Server) getHandle(handle string) (*os.File, bool) { +func (svr *Server) getHandle(handle string) (file, bool) { svr.openFilesLock.RLock() defer svr.openFilesLock.RUnlock() f, ok := svr.openFiles[handle] @@ -86,7 +100,7 @@ func NewServer(rwc io.ReadWriteCloser, options ...ServerOption) (*Server, error) serverConn: svrConn, debugStream: ioutil.Discard, pktMgr: newPktMgr(svrConn), - openFiles: make(map[string]*os.File), + openFiles: make(map[string]file), maxTxPacket: defaultMaxTxPacket, } @@ -118,6 +132,14 @@ func ReadOnly() ServerOption { } } +// WindowsRootEnumeratesDrives configures a Server to serve a virtual '/' for windows that lists all drives +func WindowsRootEnumeratesDrives() ServerOption { + return func(s *Server) error { + s.winRoot = true + return nil + } +} + // WithAllocator enable the allocator. // After processing a packet we keep in memory the allocated slices // and we reuse them for new packets. @@ -215,7 +237,7 @@ func handlePacket(s *Server, p orderedRequest) error { } case *sshFxpLstatPacket: // stat the requested file - info, err := os.Lstat(s.toLocalPath(p.Path)) + info, err := s.lstat(s.toLocalPath(p.Path)) rpkt = &sshFxpStatResponse{ ID: p.ID, info: info, @@ -289,7 +311,7 @@ func handlePacket(s *Server, p orderedRequest) error { case *sshFxpOpendirPacket: lp := s.toLocalPath(p.Path) - if stat, err := os.Stat(lp); err != nil { + if stat, err := s.stat(lp); err != nil { rpkt = statusFromError(p.ID, err) } else if !stat.IsDir() { rpkt = statusFromError(p.ID, &os.PathError{ @@ -493,7 +515,7 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket { mode = fs.FileMode() & os.ModePerm } - f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, mode) + f, err := svr.openfile(svr.toLocalPath(p.Path), osFlags, mode) if err != nil { return statusFromError(p.ID, err) } diff --git a/server_posix.go b/server_posix.go new file mode 100644 index 00000000..c07d70a0 --- /dev/null +++ b/server_posix.go @@ -0,0 +1,21 @@ +//go:build !windows +// +build !windows + +package sftp + +import ( + "io/fs" + "os" +) + +func (s *Server) openfile(path string, flag int, mode fs.FileMode) (file, error) { + return os.OpenFile(path, flag, mode) +} + +func (s *Server) lstat(name string) (os.FileInfo, error) { + return os.Lstat(name) +} + +func (s *Server) stat(name string) (os.FileInfo, error) { + return os.Stat(name) +} diff --git a/server_windows.go b/server_windows.go index b35be730..e940dba1 100644 --- a/server_windows.go +++ b/server_windows.go @@ -1,8 +1,15 @@ package sftp import ( + "fmt" + "io" + "io/fs" + "os" "path" "path/filepath" + "time" + + "golang.org/x/sys/windows" ) func (s *Server) toLocalPath(p string) string { @@ -12,7 +19,11 @@ func (s *Server) toLocalPath(p string) string { lp := filepath.FromSlash(p) - if path.IsAbs(p) { + if path.IsAbs(p) { // starts with '/' + if len(p) == 1 && s.winRoot { + return `\\.\` // for openfile + } + tmp := lp for len(tmp) > 0 && tmp[0] == '\\' { tmp = tmp[1:] @@ -33,7 +44,150 @@ func (s *Server) toLocalPath(p string) string { // e.g. "/C:" to "C:\\" return tmp } + + if s.winRoot { + // Make it so that "/Windows" is not found, and "/c:/Windows" has to be used + return `\\.\` + tmp + } } return lp } + +func bitsToDrives(bitmap uint32) []string { + var drive rune = 'a' + var drives []string + + for bitmap != 0 && drive <= 'z' { + if bitmap&1 == 1 { + drives = append(drives, string(drive)+":") + } + drive++ + bitmap >>= 1 + } + + return drives +} + +func getDrives() ([]string, error) { + mask, err := windows.GetLogicalDrives() + if err != nil { + return nil, fmt.Errorf("GetLogicalDrives: %w", err) + } + return bitsToDrives(mask), nil +} + +type driveInfo struct { + fs.FileInfo + name string +} + +func (i *driveInfo) Name() string { + return i.name // since the Name() returned from a os.Stat("C:\\") is "\\" +} + +type winRoot struct { + drives []string +} + +func newWinRoot() (*winRoot, error) { + drives, err := getDrives() + if err != nil { + return nil, err + } + return &winRoot{ + drives: drives, + }, nil +} + +func (f *winRoot) Readdir(n int) ([]os.FileInfo, error) { + drives := f.drives + if n > 0 && len(drives) > n { + drives = drives[:n] + } + f.drives = f.drives[len(drives):] + if len(drives) == 0 { + return nil, io.EOF + } + + var infos []os.FileInfo + for _, drive := range drives { + fi, err := os.Stat(drive + `\`) + if err != nil { + return nil, err + } + + di := &driveInfo{ + FileInfo: fi, + name: drive, + } + infos = append(infos, di) + } + + return infos, nil +} + +func (f *winRoot) Stat() (os.FileInfo, error) { + return rootFileInfo, nil +} +func (f *winRoot) ReadAt(b []byte, off int64) (int, error) { + return 0, os.ErrPermission +} +func (f *winRoot) WriteAt(b []byte, off int64) (int, error) { + return 0, os.ErrPermission +} +func (f *winRoot) Name() string { + return "/" +} +func (f *winRoot) Truncate(int64) error { + return os.ErrPermission +} +func (f *winRoot) Chmod(mode fs.FileMode) error { + return os.ErrPermission +} +func (f *winRoot) Chown(uid, gid int) error { + return os.ErrPermission +} +func (f *winRoot) Close() error { + f.drives = nil + return nil +} + +func (s *Server) openfile(path string, flag int, mode fs.FileMode) (file, error) { + if path == `\\.\` && s.winRoot { + return newWinRoot() + } + return os.OpenFile(path, flag, mode) +} + +type winRootFileInfo struct { + name string + modTime time.Time +} + +func (w *winRootFileInfo) Name() string { return w.name } +func (w *winRootFileInfo) Size() int64 { return 0 } +func (w *winRootFileInfo) Mode() fs.FileMode { return fs.ModeDir | 0555 } // read+execute for all +func (w *winRootFileInfo) ModTime() time.Time { return w.modTime } +func (w *winRootFileInfo) IsDir() bool { return true } +func (w *winRootFileInfo) Sys() interface{} { return nil } + +// Create a new root FileInfo +var rootFileInfo = &winRootFileInfo{ + name: "/", + modTime: time.Now(), +} + +func (s *Server) lstat(name string) (os.FileInfo, error) { + if name == `\\.\` && s.winRoot { + return rootFileInfo, nil + } + return os.Lstat(name) +} + +func (s *Server) stat(name string) (os.FileInfo, error) { + if name == `\\.\` && s.winRoot { + return rootFileInfo, nil + } + return os.Stat(name) +}