From d1903fbd460e9a8105bae72fcdf492a4999b4cee Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Fri, 19 Jan 2024 00:20:23 +0000 Subject: [PATCH 01/12] rework client to prevent after-close usage, and support perm at open --- attrs.go | 19 ++++- client.go | 156 +++++++++++++++++++++++++++++++------ client_test.go | 2 +- packet.go | 93 ++++++++++++++++++---- packet_test.go | 29 +++++-- request-attrs.go | 6 -- server.go | 87 +++++++++------------ server_integration_test.go | 17 ++-- server_test.go | 82 ++++++++++++++++++- 9 files changed, 381 insertions(+), 110 deletions(-) diff --git a/attrs.go b/attrs.go index 758cd4ff..74ac03b7 100644 --- a/attrs.go +++ b/attrs.go @@ -32,10 +32,10 @@ func (fi *fileInfo) Name() string { return fi.name } func (fi *fileInfo) Size() int64 { return int64(fi.stat.Size) } // Mode returns file mode bits. -func (fi *fileInfo) Mode() os.FileMode { return toFileMode(fi.stat.Mode) } +func (fi *fileInfo) Mode() os.FileMode { return fi.stat.FileMode() } // ModTime returns the last modification time of the file. -func (fi *fileInfo) ModTime() time.Time { return time.Unix(int64(fi.stat.Mtime), 0) } +func (fi *fileInfo) ModTime() time.Time { return fi.stat.ModTime() } // IsDir returns true if the file is a directory. func (fi *fileInfo) IsDir() bool { return fi.Mode().IsDir() } @@ -56,6 +56,21 @@ type FileStat struct { Extended []StatExtended } +// ModTime returns the Mtime SFTP file attribute converted to a time.Time +func (fs *FileStat) ModTime() time.Time { + return time.Unix(int64(fs.Mtime), 0) +} + +// AccessTime returns the Atime SFTP file attribute converted to a time.Time +func (fs *FileStat) AccessTime() time.Time { + return time.Unix(int64(fs.Atime), 0) +} + +// FileMode returns the Mode SFTP file attribute converted to an os.FileMode +func (fs *FileStat) FileMode() os.FileMode { + return toFileMode(fs.Mode) +} + // StatExtended contains additional, extended information for a FileStat. type StatExtended struct { ExtType string diff --git a/client.go b/client.go index a3b8e22b..1d55aaea 100644 --- a/client.go +++ b/client.go @@ -257,7 +257,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie // read/write at the same time. For those services you will need to use // `client.OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC)`. func (c *Client) Create(path string) (*File, error) { - return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC)) + return c.open(path, toPflags(os.O_RDWR|os.O_CREATE|os.O_TRUNC)) } const sftpProtocolVersion = 3 // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt @@ -510,7 +510,7 @@ func (c *Client) Symlink(oldname, newname string) error { } } -func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error { +func (c *Client) fsetstat(handle string, flags uint32, attrs interface{}) error { id := c.nextID() typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFsetstatPacket{ ID: id, @@ -590,14 +590,14 @@ func (c *Client) Truncate(path string, size int64) error { // returned file can be used for reading; the associated file descriptor // has mode O_RDONLY. func (c *Client) Open(path string) (*File, error) { - return c.open(path, flags(os.O_RDONLY)) + return c.open(path, toPflags(os.O_RDONLY)) } // OpenFile is the generalized open call; most users will use Open or // Create instead. It opens the named file with specified flag (O_RDONLY // etc.). If successful, methods on the returned File can be used for I/O. func (c *Client) OpenFile(path string, f int) (*File, error) { - return c.open(path, flags(f)) + return c.open(path, toPflags(f)) } func (c *Client) open(path string, pflags uint32) (*File, error) { @@ -976,16 +976,26 @@ func (c *Client) RemoveAll(path string) error { type File struct { c *Client path string - handle string - mu sync.Mutex + mu sync.RWMutex + handle string offset int64 // current offset within remote file } // Close closes the File, rendering it unusable for I/O. It returns an // error, if any. func (f *File) Close() error { - return f.c.close(f.handle) + f.mu.Lock() + defer f.mu.Unlock() + + if f.handle == "" { + return os.ErrClosed + } + + handle := f.handle + f.handle = "" + + return f.c.close(handle) } // Name returns the name of the file as presented to Open or Create. @@ -1006,7 +1016,11 @@ func (f *File) Read(b []byte) (int, error) { f.mu.Lock() defer f.mu.Unlock() - n, err := f.ReadAt(b, f.offset) + if f.handle == "" { + return 0, os.ErrClosed + } + + n, err := f.readAt(b, f.offset) f.offset += int64(n) return n, err } @@ -1071,6 +1085,17 @@ func (f *File) readAtSequential(b []byte, off int64) (read int, err error) { // the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics, // so the file offset is not altered during the read. func (f *File) ReadAt(b []byte, off int64) (int, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return 0, os.ErrClosed + } + + return f.readAt(b, off) +} + +func (f *File) readAt(b []byte, off int64) (int, error) { if len(b) <= f.c.maxPacket { // This should be able to be serviced with 1/2 requests. // So, just do it directly. @@ -1267,6 +1292,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { f.mu.Lock() defer f.mu.Unlock() + if f.handle == "" { + return 0, os.ErrClosed + } + if f.c.disableConcurrentReads { return f.writeToSequential(w) } @@ -1456,9 +1485,20 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { } } +func (f *File) Stat() (os.FileInfo, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return nil, os.ErrClosed + } + + return f.stat() +} + // Stat returns the FileInfo structure describing file. If there is an // error. -func (f *File) Stat() (os.FileInfo, error) { +func (f *File) stat() (os.FileInfo, error) { fs, err := f.c.fstat(f.handle) if err != nil { return nil, err @@ -1478,7 +1518,11 @@ func (f *File) Write(b []byte) (int, error) { f.mu.Lock() defer f.mu.Unlock() - n, err := f.WriteAt(b, f.offset) + if f.handle == "" { + return 0, os.ErrClosed + } + + n, err := f.writeAt(b, f.offset) f.offset += int64(n) return n, err } @@ -1636,6 +1680,17 @@ func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) { // the number of bytes written and an error, if any. WriteAt follows io.WriterAt semantics, // so the file offset is not altered during the write. func (f *File) WriteAt(b []byte, off int64) (written int, err error) { + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return 0, os.ErrClosed + } + + return f.writeAt(b, off) +} + +func (f *File) writeAt(b []byte, off int64) (written int, err error) { if len(b) <= f.c.maxPacket { // We can do this in one write. return f.writeChunkAt(nil, b, off) @@ -1675,6 +1730,17 @@ func (f *File) WriteAt(b []byte, off int64) (written int, err error) { // // Otherwise, the given concurrency will be capped by the Client's max concurrency. func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.handle == "" { + return 0, os.ErrClosed + } + + return f.readFromWithConcurrency(r, concurrency) +} + +func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) { // Split the write into multiple maxPacket sized concurrent writes. // This allows writes with a suitably large reader // to transfer data at a much faster rate due to overlapping round trip times. @@ -1824,6 +1890,10 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { f.mu.Lock() defer f.mu.Unlock() + if f.handle == "" { + return 0, os.ErrClosed + } + if f.c.useConcurrentWrites { var remain int64 switch r := r.(type) { @@ -1845,7 +1915,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { if remain < 0 { // We can strongly assert that we want default max concurrency here. - return f.ReadFromWithConcurrency(r, f.c.maxConcurrentRequests) + return f.readFromWithConcurrency(r, f.c.maxConcurrentRequests) } if remain > int64(f.c.maxPacket) { @@ -1860,7 +1930,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { concurrency64 = int64(f.c.maxConcurrentRequests) } - return f.ReadFromWithConcurrency(r, int(concurrency64)) + return f.readFromWithConcurrency(r, int(concurrency64)) } } @@ -1903,12 +1973,16 @@ func (f *File) Seek(offset int64, whence int) (int64, error) { f.mu.Lock() defer f.mu.Unlock() + if f.handle == "" { + return 0, os.ErrClosed + } + switch whence { case io.SeekStart: case io.SeekCurrent: offset += f.offset case io.SeekEnd: - fi, err := f.Stat() + fi, err := f.stat() if err != nil { return f.offset, err } @@ -1927,20 +2001,61 @@ func (f *File) Seek(offset int64, whence int) (int64, error) { // Chown changes the uid/gid of the current file. func (f *File) Chown(uid, gid int) error { - return f.c.Chown(f.path, uid, gid) + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return os.ErrClosed + } + + return f.c.fsetstat(f.handle, sshFileXferAttrUIDGID, &FileStat{ + UID: uint32(uid), + GID: uint32(gid), + }) } // Chmod changes the permissions of the current file. // // See Client.Chmod for details. func (f *File) Chmod(mode os.FileMode) error { - return f.c.setfstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode)) + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return os.ErrClosed + } + + return f.c.fsetstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode)) +} + +// Truncate sets the size of the current file. Although it may be safely assumed +// that if the size is less than its current size it will be truncated to fit, +// the SFTP protocol does not specify what behavior the server should do when setting +// size greater than the current size. +// We send a SSH_FXP_FSETSTAT here since we have a file handle +func (f *File) Truncate(size int64) error { + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return os.ErrClosed + } + + return f.c.fsetstat(f.handle, sshFileXferAttrSize, uint64(size)) } // Sync requests a flush of the contents of a File to stable storage. // // Sync requires the server to support the fsync@openssh.com extension. func (f *File) Sync() error { + f.mu.Lock() + defer f.mu.Unlock() + + if f.handle == "" { + return os.ErrClosed + } + + id := f.c.nextID() typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{ ID: id, @@ -1957,15 +2072,6 @@ func (f *File) Sync() error { } } -// Truncate sets the size of the current file. Although it may be safely assumed -// that if the size is less than its current size it will be truncated to fit, -// the SFTP protocol does not specify what behavior the server should do when setting -// size greater than the current size. -// We send a SSH_FXP_FSETSTAT here since we have a file handle -func (f *File) Truncate(size int64) error { - return f.c.setfstat(f.handle, sshFileXferAttrSize, uint64(size)) -} - // normaliseError normalises an error into a more standard form that can be // checked against stdlib errors like io.EOF or os.ErrNotExist. func normaliseError(err error) error { @@ -1990,7 +2096,7 @@ func normaliseError(err error) error { // flags converts the flags passed to OpenFile into ssh flags. // Unsupported flags are ignored. -func flags(f int) uint32 { +func toPflags(f int) uint32 { var out uint32 switch f & os.O_WRONLY { case os.O_WRONLY: diff --git a/client_test.go b/client_test.go index 4577ca22..dda8af2b 100644 --- a/client_test.go +++ b/client_test.go @@ -81,7 +81,7 @@ var flagsTests = []struct { func TestFlags(t *testing.T) { for i, tt := range flagsTests { - got := flags(tt.flags) + got := toPflags(tt.flags) if got != tt.want { t.Errorf("test %v: flags(%x): want: %x, got: %x", i, tt.flags, tt.want, got) } diff --git a/packet.go b/packet.go index 1232ff1e..2fea2bef 100644 --- a/packet.go +++ b/packet.go @@ -56,6 +56,11 @@ func marshalFileInfo(b []byte, fi os.FileInfo) []byte { flags, fileStat := fileStatFromInfo(fi) b = marshalUint32(b, flags) + + return marshalFileStat(b, flags, fileStat) +} + +func marshalFileStat(b []byte, flags uint32, fileStat *FileStat) []byte { if flags&sshFileXferAttrSize != 0 { b = marshalUint64(b, fileStat.Size) } @@ -91,10 +96,9 @@ func marshalStatus(b []byte, err StatusError) []byte { } func marshal(b []byte, v interface{}) []byte { - if v == nil { - return b - } switch v := v.(type) { + case nil: + return b case uint8: return append(b, v) case uint32: @@ -103,6 +107,8 @@ func marshal(b []byte, v interface{}) []byte { return marshalUint64(b, v) case string: return marshalString(b, v) + case []byte: + return append(b, v...) case os.FileInfo: return marshalFileInfo(b, v) default: @@ -180,8 +186,6 @@ func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) { } if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID { fs.UID, b, _ = unmarshalUint32Safe(b) - } - if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID { fs.GID, b, _ = unmarshalUint32Safe(b) } if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions { @@ -681,12 +685,13 @@ type sshFxpOpenPacket struct { ID uint32 Path string Pflags uint32 - Flags uint32 // ignored + Flags uint32 + Attrs interface{} } func (p *sshFxpOpenPacket) id() uint32 { return p.ID } -func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) { +func (p *sshFxpOpenPacket) marshalPacket() ([]byte, []byte, error) { l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Path) + 4 + 4 @@ -698,7 +703,20 @@ func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) { b = marshalUint32(b, p.Pflags) b = marshalUint32(b, p.Flags) - return b, nil + switch attrs := p.Attrs.(type) { + case os.FileInfo: + _, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet. + return b, marshalFileStat(nil, p.Flags, fs), nil + case *FileStat: + return b, marshalFileStat(nil, p.Flags, attrs), nil + } + + return b, marshal(nil, p.Attrs), nil +} + +func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err } func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error { @@ -709,12 +727,25 @@ func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error { return err } else if p.Pflags, b, err = unmarshalUint32Safe(b); err != nil { return err - } else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil { + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { return err } + p.Attrs = b return nil } +func (p *sshFxpOpenPacket) unmarshalFileStat(flags uint32) *FileStat { + switch attrs := p.Attrs.(type) { + case *FileStat: + return attrs + case []byte: + fs, _ := unmarshalFileStat(flags, attrs) + return fs + default: + panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs)) + } +} + type sshFxpReadPacket struct { ID uint32 Len uint32 @@ -943,9 +974,15 @@ func (p *sshFxpSetstatPacket) marshalPacket() ([]byte, []byte, error) { b = marshalString(b, p.Path) b = marshalUint32(b, p.Flags) - payload := marshal(nil, p.Attrs) + switch attrs := p.Attrs.(type) { + case os.FileInfo: + _, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet. + return b, marshalFileStat(nil, p.Flags, fs), nil + case *FileStat: + return b, marshalFileStat(nil, p.Flags, attrs), nil + } - return b, payload, nil + return b, marshal(nil, p.Attrs), nil } func (p *sshFxpSetstatPacket) MarshalBinary() ([]byte, error) { @@ -964,9 +1001,15 @@ func (p *sshFxpFsetstatPacket) marshalPacket() ([]byte, []byte, error) { b = marshalString(b, p.Handle) b = marshalUint32(b, p.Flags) - payload := marshal(nil, p.Attrs) + switch attrs := p.Attrs.(type) { + case os.FileInfo: + _, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet. + return b, marshalFileStat(nil, p.Flags, fs), nil + case *FileStat: + return b, marshalFileStat(nil, p.Flags, attrs), nil + } - return b, payload, nil + return b, marshal(nil, p.Attrs), nil } func (p *sshFxpFsetstatPacket) MarshalBinary() ([]byte, error) { @@ -987,6 +1030,18 @@ func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error { return nil } +func (p *sshFxpSetstatPacket) unmarshalFileStat(flags uint32) *FileStat { + switch attrs := p.Attrs.(type) { + case *FileStat: + return attrs + case []byte: + fs, _ := unmarshalFileStat(flags, attrs) + return fs + default: + panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs)) + } +} + func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error { var err error if p.ID, b, err = unmarshalUint32Safe(b); err != nil { @@ -1000,6 +1055,18 @@ func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error { return nil } +func (p *sshFxpFsetstatPacket) unmarshalFileStat(flags uint32) *FileStat { + switch attrs := p.Attrs.(type) { + case *FileStat: + return attrs + case []byte: + fs, _ := unmarshalFileStat(flags, attrs) + return fs + default: + panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs)) + } +} + type sshFxpHandlePacket struct { ID uint32 Handle string diff --git a/packet_test.go b/packet_test.go index cbee5e4c..6278ca4f 100644 --- a/packet_test.go +++ b/packet_test.go @@ -376,7 +376,7 @@ func TestSendPacket(t *testing.T) { packet: &sshFxpOpenPacket{ ID: 1, Path: "/foo", - Pflags: flags(os.O_RDONLY), + Pflags: toPflags(os.O_RDONLY), }, want: []byte{ 0x0, 0x0, 0x0, 0x15, @@ -387,6 +387,26 @@ func TestSendPacket(t *testing.T) { 0x0, 0x0, 0x0, 0x0, }, }, + { + packet: &sshFxpOpenPacket{ + ID: 3, + Path: "/foo", + Pflags: toPflags(os.O_WRONLY | os.O_CREATE | os.O_TRUNC), + Flags: sshFileXferAttrPermissions, + Attrs: &FileStat{ + Mode: 0o755, + }, + }, + want: []byte{ + 0x0, 0x0, 0x0, 0x19, + 0x3, + 0x0, 0x0, 0x0, 0x3, + 0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o', + 0x0, 0x0, 0x0, 0x1a, + 0x0, 0x0, 0x0, 0x4, + 0x0, 0x0, 0x1, 0xed, + }, + }, { packet: &sshFxpWritePacket{ ID: 124, @@ -409,10 +429,7 @@ func TestSendPacket(t *testing.T) { ID: 31, Path: "/bar", Flags: sshFileXferAttrUIDGID, - Attrs: struct { - UID uint32 - GID uint32 - }{ + Attrs: &FileStat{ UID: 1000, GID: 100, }, @@ -611,7 +628,7 @@ func BenchmarkMarshalOpen(b *testing.B) { benchMarshal(b, &sshFxpOpenPacket{ ID: 1, Path: "/home/test/some/random/path", - Pflags: flags(os.O_RDONLY), + Pflags: toPflags(os.O_RDONLY), }) } diff --git a/request-attrs.go b/request-attrs.go index b5c95b4a..c86539cc 100644 --- a/request-attrs.go +++ b/request-attrs.go @@ -3,7 +3,6 @@ package sftp // Methods on the Request object to make working with the Flags bitmasks and // Attr(ibutes) byte blob easier. Use Pflags() when working with an Open/Write // request and AttrFlags() and Attributes() when working with SetStat requests. -import "os" // FileOpenFlags defines Open and Write Flags. Correlate directly with with os.OpenFile flags // (https://golang.org/pkg/os/#pkg-constants). @@ -50,11 +49,6 @@ func (r *Request) AttrFlags() FileAttrFlags { return newFileAttrFlags(r.Flags) } -// FileMode returns the Mode SFTP file attributes wrapped as os.FileMode -func (a FileStat) FileMode() os.FileMode { - return os.FileMode(a.Mode) -} - // Attributes parses file attributes byte blob and return them in a // FileStat object. func (r *Request) Attributes() *FileStat { diff --git a/server.go b/server.go index 2e419f59..6e53e264 100644 --- a/server.go +++ b/server.go @@ -13,7 +13,6 @@ import ( "strconv" "sync" "syscall" - "time" ) const ( @@ -462,7 +461,15 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket { osFlags |= os.O_EXCL } - f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, 0o644) + mode := os.FileMode(0o644) + // Like OpenSSH, we only handle permissions here, if the file is being created. + // Otherwise, the permissions are ignored. + if p.Flags & sshFileXferAttrPermissions != 0 { + fs := p.unmarshalFileStat(p.Flags) + mode = fs.FileMode() & os.ModePerm + } + + f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, mode) if err != nil { return statusFromError(p.ID, err) } @@ -496,43 +503,32 @@ func (p *sshFxpReaddirPacket) respond(svr *Server) responsePacket { } func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket { - // additional unmarshalling is required for each possibility here - b := p.Attrs.([]byte) - var err error + path := svr.toLocalPath(p.Path) - p.Path = svr.toLocalPath(p.Path) + debug("setstat name %q", path) + + fs := p.unmarshalFileStat(p.Flags) + + var err error - debug("setstat name \"%s\"", p.Path) if (p.Flags & sshFileXferAttrSize) != 0 { - var size uint64 - if size, b, err = unmarshalUint64Safe(b); err == nil { - err = os.Truncate(p.Path, int64(size)) + if err == nil { + err = os.Truncate(path, int64(fs.Size)) } } if (p.Flags & sshFileXferAttrPermissions) != 0 { - var mode uint32 - if mode, b, err = unmarshalUint32Safe(b); err == nil { - err = os.Chmod(p.Path, os.FileMode(mode)) + if err == nil { + err = os.Chmod(path, fs.FileMode()) } } if (p.Flags & sshFileXferAttrACmodTime) != 0 { - var atime uint32 - var mtime uint32 - if atime, b, err = unmarshalUint32Safe(b); err != nil { - } else if mtime, b, err = unmarshalUint32Safe(b); err != nil { - } else { - atimeT := time.Unix(int64(atime), 0) - mtimeT := time.Unix(int64(mtime), 0) - err = os.Chtimes(p.Path, atimeT, mtimeT) + if err == nil { + err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) } } if (p.Flags & sshFileXferAttrUIDGID) != 0 { - var uid uint32 - var gid uint32 - if uid, b, err = unmarshalUint32Safe(b); err != nil { - } else if gid, _, err = unmarshalUint32Safe(b); err != nil { - } else { - err = os.Chown(p.Path, int(uid), int(gid)) + if err == nil { + err = os.Chown(path, int(fs.UID), int(fs.GID)) } } @@ -545,41 +541,32 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket { return statusFromError(p.ID, EBADF) } - // additional unmarshalling is required for each possibility here - b := p.Attrs.([]byte) + path := f.Name() + + debug("fsetstat name %q", path) + + fs := p.unmarshalFileStat(p.Flags) + var err error - debug("fsetstat name \"%s\"", f.Name()) if (p.Flags & sshFileXferAttrSize) != 0 { - var size uint64 - if size, b, err = unmarshalUint64Safe(b); err == nil { - err = f.Truncate(int64(size)) + if err == nil { + err = f.Truncate(int64(fs.Size)) } } if (p.Flags & sshFileXferAttrPermissions) != 0 { - var mode uint32 - if mode, b, err = unmarshalUint32Safe(b); err == nil { - err = f.Chmod(os.FileMode(mode)) + if err == nil { + err = f.Chmod(fs.FileMode()) } } if (p.Flags & sshFileXferAttrACmodTime) != 0 { - var atime uint32 - var mtime uint32 - if atime, b, err = unmarshalUint32Safe(b); err != nil { - } else if mtime, b, err = unmarshalUint32Safe(b); err != nil { - } else { - atimeT := time.Unix(int64(atime), 0) - mtimeT := time.Unix(int64(mtime), 0) - err = os.Chtimes(f.Name(), atimeT, mtimeT) + if err == nil { + err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) } } if (p.Flags & sshFileXferAttrUIDGID) != 0 { - var uid uint32 - var gid uint32 - if uid, b, err = unmarshalUint32Safe(b); err != nil { - } else if gid, _, err = unmarshalUint32Safe(b); err != nil { - } else { - err = f.Chown(int(uid), int(gid)) + if err == nil { + err = f.Chown(int(fs.UID), int(fs.GID)) } } diff --git a/server_integration_test.go b/server_integration_test.go index 407d38a2..74a6f8a1 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -591,18 +591,25 @@ ls -l /usr/bin/ goWords := spaceRegex.Split(goLine, -1) opWords := spaceRegex.Split(opLine, -1) // some fields are allowed to be different.. - // words[2] and [3] as these are users & groups - // words[1] as the link count for directories like proc is unstable // during testing as processes are created/destroyed. - // words[7] as timestamp on dirs can very for things like /tmp for j, goWord := range goWords { if j >= len(opWords) { bad = true break } opWord := opWords[j] - if goWord != opWord && j != 1 && j != 2 && j != 3 && j != 7 { - bad = true + if goWord != opWord { + switch j { + case 1, 2, 3, 7: + // words[1] as the link count for directories like proc is unstable + // words[2] and [3] as these are users & groups + // words[7] as timestamps on dirs can vary for things like /tmp + case 8: + // words[8] can either have full path or just the filename + bad = !strings.HasSuffix(opWord, "/" + goWord) + default: + bad = true + } } } } diff --git a/server_test.go b/server_test.go index 87beece5..110e0dee 100644 --- a/server_test.go +++ b/server_test.go @@ -178,21 +178,22 @@ func TestOpenStatRace(t *testing.T) { // openpacket finishes to fast to trigger race in tests // need to add a small sleep on server to openpackets somehow tmppath := path.Join(os.TempDir(), "stat_race") - pflags := flags(os.O_RDWR | os.O_CREATE | os.O_TRUNC) + pflags := toPflags(os.O_RDWR | os.O_CREATE | os.O_TRUNC) ch := make(chan result, 3) id1 := client.nextID() + id2 := client.nextID() client.dispatchRequest(ch, &sshFxpOpenPacket{ ID: id1, Path: tmppath, Pflags: pflags, }) - id2 := client.nextID() client.dispatchRequest(ch, &sshFxpLstatPacket{ ID: id2, Path: tmppath, }) testreply := func(id uint32) { r := <-ch + require.NoError(t, r.err) switch r.typ { case sshFxpAttrs, sshFxpHandle: // ignore case sshFxpStatus: @@ -208,6 +209,83 @@ func TestOpenStatRace(t *testing.T) { checkServerAllocator(t, server) } +func TestOpenWithPermissions(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + skipIfWindows(t) + + client, server := clientServerPair(t) + defer client.Close() + defer server.Close() + + tmppath := path.Join(os.TempDir(), "open_permissions") + defer os.Remove(tmppath) + + pflags := toPflags(os.O_RDWR | os.O_CREATE | os.O_TRUNC) + + id1 := client.nextID() + id2 := client.nextID() + + typ, data, err := client.sendPacket(ctx, nil, &sshFxpOpenPacket{ + ID: id1, + Path: tmppath, + Pflags: pflags, + Flags: sshFileXferAttrPermissions, + Attrs: &FileStat{ + Mode: 0o745, + }, + }) + if err != nil { + t.Fatal("unexpected error:", err) + } + switch typ { + case sshFxpHandle: + // do nothing, we can just leave the handle open. + case sshFxpStatus: + t.Fatal("unexpected status:", normaliseError(unmarshalStatus(id1, data))) + default: + t.Fatal("unpexpected packet type:", unimplementedPacketErr(typ)) + } + + stat, err := os.Stat(tmppath) + if err != nil { + t.Fatal("unexpected error:", err) + } + + if stat.Mode()&os.ModePerm != 0o745 { + t.Errorf("stat.Mode() = %v was expecting 0o745", stat.Mode()) + } + + // Existing files should not have their permissions changed. + typ, data, err = client.sendPacket(ctx, nil, &sshFxpOpenPacket{ + ID: id2, + Path: tmppath, + Pflags: pflags, + Flags: sshFileXferAttrPermissions, + Attrs: &FileStat{ + Mode: 0o755, + }, + }) + if err != nil { + t.Fatal("unexpected error:", err) + } + switch typ { + case sshFxpHandle: + // do nothing, we can just leave the handle open. + case sshFxpStatus: + t.Fatal("unexpected status:", normaliseError(unmarshalStatus(id2, data))) + default: + t.Fatal("unpexpected packet type:", unimplementedPacketErr(typ)) + } + + if stat.Mode()&os.ModePerm != 0o745 { + t.Errorf("stat.Mode() = %v, was expecting unchanged 0o745", stat.Mode()) + } + + checkServerAllocator(t, server) +} + // Ensure that proper error codes are returned for non existent files, such // that they are mapped back to a 'not exists' error on the client side. func TestStatNonExistent(t *testing.T) { From f3501dc6ba301548dc514108039e66f319748a1a Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Fri, 19 Jan 2024 01:23:22 +0000 Subject: [PATCH 02/12] address code review --- client.go | 30 ++++++++----- packet.go | 89 ++++++++++++++++++++++++++------------ packet_test.go | 11 +++-- request-attrs.go | 2 +- request-attrs_test.go | 18 ++++++-- server.go | 28 +++++++----- server_integration_test.go | 2 +- server_test.go | 12 ++--- 8 files changed, 127 insertions(+), 65 deletions(-) diff --git a/client.go b/client.go index 1d55aaea..12d105ad 100644 --- a/client.go +++ b/client.go @@ -363,7 +363,10 @@ func (c *Client) ReadDirContext(ctx context.Context, p string) ([]os.FileInfo, e filename, data = unmarshalString(data) _, data = unmarshalString(data) // discard longname var attr *FileStat - attr, data = unmarshalAttrs(data) + attr, data, err = unmarshalAttrs(data) + if err != nil { + return nil, err + } if filename == "." || filename == ".." { continue } @@ -434,8 +437,8 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) { if sid != id { return nil, &unexpectedIDErr{id, sid} } - attr, _ := unmarshalAttrs(data) - return fileInfoFromStat(attr, path.Base(p)), nil + attr, _, err := unmarshalAttrs(data) + return fileInfoFromStat(attr, path.Base(p)), err case sshFxpStatus: return nil, normaliseError(unmarshalStatus(id, data)) default: @@ -660,8 +663,8 @@ func (c *Client) stat(path string) (*FileStat, error) { if sid != id { return nil, &unexpectedIDErr{id, sid} } - attr, _ := unmarshalAttrs(data) - return attr, nil + attr, _, err := unmarshalAttrs(data) + return attr, err case sshFxpStatus: return nil, normaliseError(unmarshalStatus(id, data)) default: @@ -684,8 +687,8 @@ func (c *Client) fstat(handle string) (*FileStat, error) { if sid != id { return nil, &unexpectedIDErr{id, sid} } - attr, _ := unmarshalAttrs(data) - return attr, nil + attr, _, err := unmarshalAttrs(data) + return attr, err case sshFxpStatus: return nil, normaliseError(unmarshalStatus(id, data)) default: @@ -974,8 +977,8 @@ func (c *Client) RemoveAll(path string) error { // File represents a remote file. type File struct { - c *Client - path string + c *Client + path string mu sync.RWMutex handle string @@ -992,6 +995,10 @@ func (f *File) Close() error { return os.ErrClosed } + // When `openssh-portable/sftp-server.c` is doing `handle_close`, + // it will unconditionally mark the handle as unused, + // so we need to also unconditionally mark this handle as invalid. + handle := f.handle f.handle = "" @@ -1485,6 +1492,8 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { } } +// Stat returns the FileInfo structure describing file. If there is an +// error. func (f *File) Stat() (os.FileInfo, error) { f.mu.RLock() defer f.mu.RUnlock() @@ -1496,8 +1505,6 @@ func (f *File) Stat() (os.FileInfo, error) { return f.stat() } -// Stat returns the FileInfo structure describing file. If there is an -// error. func (f *File) stat() (os.FileInfo, error) { fs, err := f.c.fstat(f.handle) if err != nil { @@ -2055,7 +2062,6 @@ func (f *File) Sync() error { return os.ErrClosed } - id := f.c.nextID() typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{ ID: id, diff --git a/packet.go b/packet.go index 2fea2bef..f37cd4dc 100644 --- a/packet.go +++ b/packet.go @@ -174,36 +174,69 @@ func unmarshalStringSafe(b []byte) (string, []byte, error) { return string(b[:n]), b[n:], nil } -func unmarshalAttrs(b []byte) (*FileStat, []byte) { - flags, b := unmarshalUint32(b) +func unmarshalAttrs(b []byte) (*FileStat, []byte, error) { + flags, b, err := unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } return unmarshalFileStat(flags, b) } -func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) { +func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte, error) { var fs FileStat + var err error + if flags&sshFileXferAttrSize == sshFileXferAttrSize { - fs.Size, b, _ = unmarshalUint64Safe(b) + fs.Size, b, err = unmarshalUint64Safe(b) + if err != nil { + return nil, b, err + } } if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID { - fs.UID, b, _ = unmarshalUint32Safe(b) - fs.GID, b, _ = unmarshalUint32Safe(b) + fs.UID, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } + fs.GID, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } } if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions { - fs.Mode, b, _ = unmarshalUint32Safe(b) + fs.Mode, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } } if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime { - fs.Atime, b, _ = unmarshalUint32Safe(b) - fs.Mtime, b, _ = unmarshalUint32Safe(b) + fs.Atime, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } + fs.Mtime, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } } if flags&sshFileXferAttrExtended == sshFileXferAttrExtended { var count uint32 - count, b, _ = unmarshalUint32Safe(b) + count, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } + ext := make([]StatExtended, count) for i := uint32(0); i < count; i++ { var typ string var data string - typ, b, _ = unmarshalStringSafe(b) - data, b, _ = unmarshalStringSafe(b) + typ, b, err = unmarshalStringSafe(b) + if err != nil { + return nil, b, err + } + data, b, err = unmarshalStringSafe(b) + if err != nil { + return nil, b, err + } ext[i] = StatExtended{ ExtType: typ, ExtData: data, @@ -211,7 +244,7 @@ func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) { } fs.Extended = ext } - return &fs, b + return &fs, b, nil } func unmarshalStatus(id uint32, data []byte) error { @@ -734,15 +767,15 @@ func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error { return nil } -func (p *sshFxpOpenPacket) unmarshalFileStat(flags uint32) *FileStat { +func (p *sshFxpOpenPacket) unmarshalFileStat(flags uint32) (*FileStat, error) { switch attrs := p.Attrs.(type) { case *FileStat: - return attrs + return attrs, nil case []byte: - fs, _ := unmarshalFileStat(flags, attrs) - return fs + fs, _, err := unmarshalFileStat(flags, attrs) + return fs, err default: - panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs)) + return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs) } } @@ -1030,15 +1063,15 @@ func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error { return nil } -func (p *sshFxpSetstatPacket) unmarshalFileStat(flags uint32) *FileStat { +func (p *sshFxpSetstatPacket) unmarshalFileStat(flags uint32) (*FileStat, error) { switch attrs := p.Attrs.(type) { case *FileStat: - return attrs + return attrs, nil case []byte: - fs, _ := unmarshalFileStat(flags, attrs) - return fs + fs, _, err := unmarshalFileStat(flags, attrs) + return fs, err default: - panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs)) + return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs) } } @@ -1055,15 +1088,15 @@ func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error { return nil } -func (p *sshFxpFsetstatPacket) unmarshalFileStat(flags uint32) *FileStat { +func (p *sshFxpFsetstatPacket) unmarshalFileStat(flags uint32) (*FileStat, error) { switch attrs := p.Attrs.(type) { case *FileStat: - return attrs + return attrs, nil case []byte: - fs, _ := unmarshalFileStat(flags, attrs) - return fs + fs, _, err := unmarshalFileStat(flags, attrs) + return fs, err default: - panic(fmt.Sprintf("invalid type in unmarshalFileStat: %T", attrs)) + return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs) } } diff --git a/packet_test.go b/packet_test.go index 6278ca4f..98455abe 100644 --- a/packet_test.go +++ b/packet_test.go @@ -284,7 +284,10 @@ func TestUnmarshalAttrs(t *testing.T) { } for _, tt := range tests { - got, _ := unmarshalAttrs(tt.b) + got, _, err := unmarshalAttrs(tt.b) + if err != nil { + t.Fatal("unexpected error:", err) + } if !reflect.DeepEqual(got, tt.want) { t.Errorf("unmarshalAttrs(% X):\n- got: %#v\n- want: %#v", tt.b, got, tt.want) } @@ -389,11 +392,11 @@ func TestSendPacket(t *testing.T) { }, { packet: &sshFxpOpenPacket{ - ID: 3, - Path: "/foo", + ID: 3, + Path: "/foo", Pflags: toPflags(os.O_WRONLY | os.O_CREATE | os.O_TRUNC), Flags: sshFileXferAttrPermissions, - Attrs: &FileStat{ + Attrs: &FileStat{ Mode: 0o755, }, }, diff --git a/request-attrs.go b/request-attrs.go index c86539cc..476c5651 100644 --- a/request-attrs.go +++ b/request-attrs.go @@ -52,6 +52,6 @@ func (r *Request) AttrFlags() FileAttrFlags { // Attributes parses file attributes byte blob and return them in a // FileStat object. func (r *Request) Attributes() *FileStat { - fs, _ := unmarshalFileStat(r.Flags, r.Attrs) + fs, _, _ := unmarshalFileStat(r.Flags, r.Attrs) return fs } diff --git a/request-attrs_test.go b/request-attrs_test.go index 658afca0..b1b559b8 100644 --- a/request-attrs_test.go +++ b/request-attrs_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestRequestPflags(t *testing.T) { @@ -33,7 +34,8 @@ func TestRequestAttributes(t *testing.T) { at := []byte{} at = marshalUint32(at, 1) at = marshalUint32(at, 2) - testFs, _ := unmarshalFileStat(fl, at) + testFs, _, err := unmarshalFileStat(fl, at) + require.NoError(t, err) assert.Equal(t, fa, *testFs) // Size and Mode fa = FileStat{Mode: 0700, Size: 99} @@ -41,7 +43,8 @@ func TestRequestAttributes(t *testing.T) { at = []byte{} at = marshalUint64(at, 99) at = marshalUint32(at, 0700) - testFs, _ = unmarshalFileStat(fl, at) + testFs, _, err = unmarshalFileStat(fl, at) + require.NoError(t, err) assert.Equal(t, fa, *testFs) // FileMode assert.True(t, testFs.FileMode().IsRegular()) @@ -50,7 +53,16 @@ func TestRequestAttributes(t *testing.T) { } func TestRequestAttributesEmpty(t *testing.T) { - fs, b := unmarshalFileStat(sshFileXferAttrAll, nil) + fs, b, err := unmarshalFileStat(sshFileXferAttrAll, []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // size + 0x00, 0x00, 0x00, 0x00, // mode + 0x00, 0x00, 0x00, 0x00, // mtime + 0x00, 0x00, 0x00, 0x00, // atime + 0x00, 0x00, 0x00, 0x00, // uid + 0x00, 0x00, 0x00, 0x00, // gid + 0x00, 0x00, 0x00, 0x00, // extended_count + }) + require.NoError(t, err) assert.Equal(t, &FileStat{ Extended: []StatExtended{}, }, fs) diff --git a/server.go b/server.go index 6e53e264..16f1cabc 100644 --- a/server.go +++ b/server.go @@ -13,6 +13,7 @@ import ( "strconv" "sync" "syscall" + "time" ) const ( @@ -462,10 +463,13 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket { } mode := os.FileMode(0o644) - // Like OpenSSH, we only handle permissions here, if the file is being created. + // Like OpenSSH, we only handle permissions here, and only when the file is being created. // Otherwise, the permissions are ignored. - if p.Flags & sshFileXferAttrPermissions != 0 { - fs := p.unmarshalFileStat(p.Flags) + if p.Flags&sshFileXferAttrPermissions != 0 { + fs, err := p.unmarshalFileStat(p.Flags) + if err != nil { + return statusFromError(p.ID, err) + } mode = fs.FileMode() & os.ModePerm } @@ -507,9 +511,7 @@ func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket { debug("setstat name %q", path) - fs := p.unmarshalFileStat(p.Flags) - - var err error + fs, err := p.unmarshalFileStat(p.Flags) if (p.Flags & sshFileXferAttrSize) != 0 { if err == nil { @@ -545,9 +547,7 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket { debug("fsetstat name %q", path) - fs := p.unmarshalFileStat(p.Flags) - - var err error + fs, err := p.unmarshalFileStat(p.Flags) if (p.Flags & sshFileXferAttrSize) != 0 { if err == nil { @@ -561,7 +561,15 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket { } if (p.Flags & sshFileXferAttrACmodTime) != 0 { if err == nil { - err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) + switch f := interface{}(f).(type) { + case interface { + Chtimes(atime, mtime time.Time) error + }: + // future-compatible, if any when *os.File supports Chtimes. + err = f.Chtimes(fs.AccessTime(), fs.ModTime()) + default: + err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) + } } } if (p.Flags & sshFileXferAttrUIDGID) != 0 { diff --git a/server_integration_test.go b/server_integration_test.go index 74a6f8a1..398ea865 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -606,7 +606,7 @@ ls -l /usr/bin/ // words[7] as timestamps on dirs can vary for things like /tmp case 8: // words[8] can either have full path or just the filename - bad = !strings.HasSuffix(opWord, "/" + goWord) + bad = !strings.HasSuffix(opWord, "/"+goWord) default: bad = true } diff --git a/server_test.go b/server_test.go index 110e0dee..4cec3123 100644 --- a/server_test.go +++ b/server_test.go @@ -228,11 +228,11 @@ func TestOpenWithPermissions(t *testing.T) { id2 := client.nextID() typ, data, err := client.sendPacket(ctx, nil, &sshFxpOpenPacket{ - ID: id1, - Path: tmppath, + ID: id1, + Path: tmppath, Pflags: pflags, Flags: sshFileXferAttrPermissions, - Attrs: &FileStat{ + Attrs: &FileStat{ Mode: 0o745, }, }) @@ -259,11 +259,11 @@ func TestOpenWithPermissions(t *testing.T) { // Existing files should not have their permissions changed. typ, data, err = client.sendPacket(ctx, nil, &sshFxpOpenPacket{ - ID: id2, - Path: tmppath, + ID: id2, + Path: tmppath, Pflags: pflags, Flags: sshFileXferAttrPermissions, - Attrs: &FileStat{ + Attrs: &FileStat{ Mode: 0o755, }, }) From e21cd9480548a5fb15ae9e399af85c7b43b2784e Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Fri, 19 Jan 2024 01:27:16 +0000 Subject: [PATCH 03/12] move setting times to the last operation so chown doesn't have a chance to alter atime or mtime --- server.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/server.go b/server.go index 16f1cabc..bddcf77e 100644 --- a/server.go +++ b/server.go @@ -523,14 +523,14 @@ func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket { err = os.Chmod(path, fs.FileMode()) } } - if (p.Flags & sshFileXferAttrACmodTime) != 0 { + if (p.Flags & sshFileXferAttrUIDGID) != 0 { if err == nil { - err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) + err = os.Chown(path, int(fs.UID), int(fs.GID)) } } - if (p.Flags & sshFileXferAttrUIDGID) != 0 { + if (p.Flags & sshFileXferAttrACmodTime) != 0 { if err == nil { - err = os.Chown(path, int(fs.UID), int(fs.GID)) + err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) } } @@ -559,6 +559,11 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket { err = f.Chmod(fs.FileMode()) } } + if (p.Flags & sshFileXferAttrUIDGID) != 0 { + if err == nil { + err = f.Chown(int(fs.UID), int(fs.GID)) + } + } if (p.Flags & sshFileXferAttrACmodTime) != 0 { if err == nil { switch f := interface{}(f).(type) { @@ -572,11 +577,6 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket { } } } - if (p.Flags & sshFileXferAttrUIDGID) != 0 { - if err == nil { - err = f.Chown(int(fs.UID), int(fs.GID)) - } - } return statusFromError(p.ID, err) } From 3df3035b74829370897621dbd02abbdb3b12a4fd Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Fri, 19 Jan 2024 01:56:44 +0000 Subject: [PATCH 04/12] new race condition warning, yay --- client.go | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/client.go b/client.go index 12d105ad..434b2cb1 100644 --- a/client.go +++ b/client.go @@ -1023,10 +1023,6 @@ func (f *File) Read(b []byte) (int, error) { f.mu.Lock() defer f.mu.Unlock() - if f.handle == "" { - return 0, os.ErrClosed - } - n, err := f.readAt(b, f.offset) f.offset += int64(n) return n, err @@ -1095,14 +1091,15 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) { f.mu.RLock() defer f.mu.RUnlock() - if f.handle == "" { - return 0, os.ErrClosed - } - return f.readAt(b, off) } func (f *File) readAt(b []byte, off int64) (int, error) { + if f.handle == "" { + return 0, os.ErrClosed + } + handle := f.handle // need a local copy to prevent aberrent race detection + if len(b) <= f.c.maxPacket { // This should be able to be serviced with 1/2 requests. // So, just do it directly. @@ -1154,7 +1151,7 @@ func (f *File) readAt(b []byte, off int64) (int, error) { f.c.dispatchRequest(res, &sshFxpReadPacket{ ID: id, - Handle: f.handle, + Handle: handle, Offset: uint64(offset), Len: uint32(chunkSize), }) @@ -1302,6 +1299,7 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { if f.handle == "" { return 0, os.ErrClosed } + handle := f.handle // need a local copy to prevent aberrent race detection if f.c.disableConcurrentReads { return f.writeToSequential(w) @@ -1387,7 +1385,7 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { f.c.dispatchRequest(res, &sshFxpReadPacket{ ID: id, - Handle: f.handle, + Handle: handle, Offset: uint64(off), Len: uint32(chunkSize), }) @@ -1740,14 +1738,15 @@ func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64 f.mu.Lock() defer f.mu.Unlock() - if f.handle == "" { - return 0, os.ErrClosed - } - return f.readFromWithConcurrency(r, concurrency) } func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) { + if f.handle == "" { + return 0, os.ErrClosed + } + handle := f.handle // need a local copy to prevent aberrent race detection + // Split the write into multiple maxPacket sized concurrent writes. // This allows writes with a suitably large reader // to transfer data at a much faster rate due to overlapping round trip times. @@ -1792,7 +1791,7 @@ func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64 f.c.dispatchRequest(res, &sshFxpWritePacket{ ID: id, - Handle: f.handle, + Handle: handle, Offset: uint64(off), Length: uint32(n), Data: b[:n], From 4cd7ff45fc98c45f8afe890edf3eafb0edbc0ef7 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 6 Feb 2024 08:12:54 +0000 Subject: [PATCH 05/12] testing an idea --- client.go | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index 434b2cb1..74ed04b4 100644 --- a/client.go +++ b/client.go @@ -438,7 +438,10 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) { return nil, &unexpectedIDErr{id, sid} } attr, _, err := unmarshalAttrs(data) - return fileInfoFromStat(attr, path.Base(p)), err + if err != nil { + return nil, err + } + return fileInfoFromStat(attr, path.Base(p)), nil case sshFxpStatus: return nil, normaliseError(unmarshalStatus(id, data)) default: @@ -664,6 +667,9 @@ func (c *Client) stat(path string) (*FileStat, error) { return nil, &unexpectedIDErr{id, sid} } attr, _, err := unmarshalAttrs(data) + if err != nil { + return nil, err + } return attr, err case sshFxpStatus: return nil, normaliseError(unmarshalStatus(id, data)) @@ -1094,11 +1100,26 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) { return f.readAt(b, off) } +// readAt must be called while holding either the Read or Write mutex in File. +// This code is concurrent safe with itself, but not with Close. func (f *File) readAt(b []byte, off int64) (int, error) { if f.handle == "" { return 0, os.ErrClosed } - handle := f.handle // need a local copy to prevent aberrent race detection + + // We need to make a local copy of this handle in order to prevent aberrent race detection. + // Because the value is referenced in a sub-goroutine while holding a RWMutex, + // the race detector flags that use as a race-condition. + // By bringing this into a local variable, we are not actually resolving the concurrency issue, + // since the sub-goroutine could erroneously be using a stale handle value after a Close. + // + // However, this use is not erroneous, so long as this function cannot return before that goroutine ends. + // This is guaranteed because: + // 1. This main goroutine returns strictly after `errCh` closes. + // 2. A goroutine is using `wg.Wait` to close `errCh`. + // 3. The goroutines consuming `workCh` call `wg.Done` in a defer, and returns only when `workCh` is closed.. + // 4. The goroutine referencing `handle` closes `workCh` in a defer. + // TODO: handle := f.handle if len(b) <= f.c.maxPacket { // This should be able to be serviced with 1/2 requests. @@ -1151,7 +1172,7 @@ func (f *File) readAt(b []byte, off int64) (int, error) { f.c.dispatchRequest(res, &sshFxpReadPacket{ ID: id, - Handle: handle, + Handle: f.handle, Offset: uint64(offset), Len: uint32(chunkSize), }) @@ -1217,7 +1238,9 @@ func (f *File) readAt(b []byte, off int64) (int, error) { if err != nil { // return the offset as the start + how much we read before the error. errCh <- rErr{packet.off + int64(n), err} - return + + // DO NOT return. + // We want to ensure that workCh is drained before wg.Wait returns. } } }() @@ -1695,6 +1718,8 @@ func (f *File) WriteAt(b []byte, off int64) (written int, err error) { return f.writeAt(b, off) } +// writeAt must be called while holding either the Read or Write mutex in File. +// This code is concurrent safe with itself, but not with Close. func (f *File) writeAt(b []byte, off int64) (written int, err error) { if len(b) <= f.c.maxPacket { // We can do this in one write. From 6c7c0da80c25b69f6ce1c7b1e933ca4b08cf73ff Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 6 Feb 2024 08:27:28 +0000 Subject: [PATCH 06/12] remove warnings about aberrent race detection, I think it was real --- client.go | 28 +++++++--------------------- 1 file changed, 7 insertions(+), 21 deletions(-) diff --git a/client.go b/client.go index 74ed04b4..ee372f07 100644 --- a/client.go +++ b/client.go @@ -1107,20 +1107,6 @@ func (f *File) readAt(b []byte, off int64) (int, error) { return 0, os.ErrClosed } - // We need to make a local copy of this handle in order to prevent aberrent race detection. - // Because the value is referenced in a sub-goroutine while holding a RWMutex, - // the race detector flags that use as a race-condition. - // By bringing this into a local variable, we are not actually resolving the concurrency issue, - // since the sub-goroutine could erroneously be using a stale handle value after a Close. - // - // However, this use is not erroneous, so long as this function cannot return before that goroutine ends. - // This is guaranteed because: - // 1. This main goroutine returns strictly after `errCh` closes. - // 2. A goroutine is using `wg.Wait` to close `errCh`. - // 3. The goroutines consuming `workCh` call `wg.Done` in a defer, and returns only when `workCh` is closed.. - // 4. The goroutine referencing `handle` closes `workCh` in a defer. - // TODO: handle := f.handle - if len(b) <= f.c.maxPacket { // This should be able to be serviced with 1/2 requests. // So, just do it directly. @@ -1322,7 +1308,6 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { if f.handle == "" { return 0, os.ErrClosed } - handle := f.handle // need a local copy to prevent aberrent race detection if f.c.disableConcurrentReads { return f.writeToSequential(w) @@ -1408,7 +1393,7 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { f.c.dispatchRequest(res, &sshFxpReadPacket{ ID: id, - Handle: handle, + Handle: f.handle, Offset: uint64(off), Len: uint32(chunkSize), }) @@ -1474,9 +1459,8 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { return } - if err != nil { - return - } + // DO NOT return. + // We want to ensure that readCh is drained before wg.Wait returns. } }() } @@ -1770,7 +1754,6 @@ func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64 if f.handle == "" { return 0, os.ErrClosed } - handle := f.handle // need a local copy to prevent aberrent race detection // Split the write into multiple maxPacket sized concurrent writes. // This allows writes with a suitably large reader @@ -1816,7 +1799,7 @@ func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64 f.c.dispatchRequest(res, &sshFxpWritePacket{ ID: id, - Handle: handle, + Handle: f.handle, Offset: uint64(off), Length: uint32(n), Data: b[:n], @@ -1863,6 +1846,9 @@ func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64 if err != nil { errCh <- rwErr{work.off, err} + + // DO NOT return. + // We want to ensure that workCh is drained before wg.Wait returns. } } }() From e808920da05ae22b46a98339158a641c2feb5bcd Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 6 Feb 2024 08:31:27 +0000 Subject: [PATCH 07/12] remove unnecessary block, and explain why the one added is necessary --- client.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/client.go b/client.go index ee372f07..b30159b4 100644 --- a/client.go +++ b/client.go @@ -439,6 +439,7 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) { } attr, _, err := unmarshalAttrs(data) if err != nil { + // avoid returning a valid value from fileInfoFromStats if err != nil. return nil, err } return fileInfoFromStat(attr, path.Base(p)), nil @@ -667,9 +668,6 @@ func (c *Client) stat(path string) (*FileStat, error) { return nil, &unexpectedIDErr{id, sid} } attr, _, err := unmarshalAttrs(data) - if err != nil { - return nil, err - } return attr, err case sshFxpStatus: return nil, normaliseError(unmarshalStatus(id, data)) From ba3d6ab7c6a0b5c8e42f307e0f2095b403eb78e8 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 6 Feb 2024 08:34:29 +0000 Subject: [PATCH 08/12] explain mechanics of use-after-close protection --- client.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index b30159b4..5bc5df9b 100644 --- a/client.go +++ b/client.go @@ -999,9 +999,11 @@ func (f *File) Close() error { return os.ErrClosed } - // When `openssh-portable/sftp-server.c` is doing `handle_close`, + // The design principle here is that when `openssh-portable/sftp-server.c` is doing `handle_close`, // it will unconditionally mark the handle as unused, // so we need to also unconditionally mark this handle as invalid. + // By invalidating our local copy of the handle, + // we ensure that there cannot be any erroneous use-after-close requests sent after Close. handle := f.handle f.handle = "" From 72aa4039a11c3525d4e3f6831d897d82678ad4a5 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 6 Feb 2024 08:41:41 +0000 Subject: [PATCH 09/12] more short-circuits --- packet.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/packet.go b/packet.go index f37cd4dc..cbaa90e7 100644 --- a/packet.go +++ b/packet.go @@ -737,6 +737,8 @@ func (p *sshFxpOpenPacket) marshalPacket() ([]byte, []byte, error) { b = marshalUint32(b, p.Flags) switch attrs := p.Attrs.(type) { + case []byte: + return b, attrs, nil // may as well short-ciruit this case. case os.FileInfo: _, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet. return b, marshalFileStat(nil, p.Flags, fs), nil @@ -1008,6 +1010,8 @@ func (p *sshFxpSetstatPacket) marshalPacket() ([]byte, []byte, error) { b = marshalUint32(b, p.Flags) switch attrs := p.Attrs.(type) { + case []byte: + return b, attrs, nil // may as well short-ciruit this case. case os.FileInfo: _, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet. return b, marshalFileStat(nil, p.Flags, fs), nil @@ -1035,6 +1039,8 @@ func (p *sshFxpFsetstatPacket) marshalPacket() ([]byte, []byte, error) { b = marshalUint32(b, p.Flags) switch attrs := p.Attrs.(type) { + case []byte: + return b, attrs, nil // may as well short-ciruit this case. case os.FileInfo: _, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet. return b, marshalFileStat(nil, p.Flags, fs), nil From 3ce4d4e6e2bfb842cc7a6d075d47f70b9f2b0a5c Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 6 Feb 2024 08:51:12 +0000 Subject: [PATCH 10/12] one more race-condition causing return --- client.go | 1 - 1 file changed, 1 deletion(-) diff --git a/client.go b/client.go index 5bc5df9b..5f457221 100644 --- a/client.go +++ b/client.go @@ -1456,7 +1456,6 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { select { case readWork.cur <- writeWork: case <-cancel: - return } // DO NOT return. From 5d66cdeb9ac85f2c69eba2815f1de21d984974d5 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 6 Feb 2024 15:02:43 +0000 Subject: [PATCH 11/12] better cascading --- server.go | 64 ++++++++++++++++++++++--------------------------------- 1 file changed, 25 insertions(+), 39 deletions(-) diff --git a/server.go b/server.go index bddcf77e..acdc30ed 100644 --- a/server.go +++ b/server.go @@ -513,25 +513,17 @@ func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket { fs, err := p.unmarshalFileStat(p.Flags) - if (p.Flags & sshFileXferAttrSize) != 0 { - if err == nil { - err = os.Truncate(path, int64(fs.Size)) - } + if err == nil && (p.Flags & sshFileXferAttrSize) != 0 { + err = os.Truncate(path, int64(fs.Size)) } - if (p.Flags & sshFileXferAttrPermissions) != 0 { - if err == nil { - err = os.Chmod(path, fs.FileMode()) - } + if err == nil && (p.Flags & sshFileXferAttrPermissions) != 0 { + err = os.Chmod(path, fs.FileMode()) } - if (p.Flags & sshFileXferAttrUIDGID) != 0 { - if err == nil { - err = os.Chown(path, int(fs.UID), int(fs.GID)) - } + if err == nil && (p.Flags & sshFileXferAttrUIDGID) != 0 { + err = os.Chown(path, int(fs.UID), int(fs.GID)) } - if (p.Flags & sshFileXferAttrACmodTime) != 0 { - if err == nil { - err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) - } + if err == nil && (p.Flags & sshFileXferAttrACmodTime) != 0 { + err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) } return statusFromError(p.ID, err) @@ -549,32 +541,26 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket { fs, err := p.unmarshalFileStat(p.Flags) - if (p.Flags & sshFileXferAttrSize) != 0 { - if err == nil { - err = f.Truncate(int64(fs.Size)) - } + if err == nil && (p.Flags & sshFileXferAttrSize) != 0 { + err = f.Truncate(int64(fs.Size)) } - if (p.Flags & sshFileXferAttrPermissions) != 0 { - if err == nil { - err = f.Chmod(fs.FileMode()) - } + if err == nil && (p.Flags & sshFileXferAttrPermissions) != 0 { + err = f.Chmod(fs.FileMode()) } - if (p.Flags & sshFileXferAttrUIDGID) != 0 { - if err == nil { - err = f.Chown(int(fs.UID), int(fs.GID)) - } + if err == nil && (p.Flags & sshFileXferAttrUIDGID) != 0 { + err = f.Chown(int(fs.UID), int(fs.GID)) } - if (p.Flags & sshFileXferAttrACmodTime) != 0 { - if err == nil { - switch f := interface{}(f).(type) { - case interface { - Chtimes(atime, mtime time.Time) error - }: - // future-compatible, if any when *os.File supports Chtimes. - err = f.Chtimes(fs.AccessTime(), fs.ModTime()) - default: - err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) - } + if err == nil && (p.Flags & sshFileXferAttrACmodTime) != 0 { + type chtimer interface { + Chtimes(atime, mtime time.Time) error + } + + switch f := interface{}(f).(type) { + case chtimer: + // future-compatible, for when/if *os.File supports Chtimes. + err = f.Chtimes(fs.AccessTime(), fs.ModTime()) + default: + err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) } } From 159d28655bcb721b8f8b973477cccb2a1c3923ec Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Mon, 12 Feb 2024 07:44:23 +0000 Subject: [PATCH 12/12] populate Attrs in requestFromPacket --- request.go | 1 + 1 file changed, 1 insertion(+) diff --git a/request.go b/request.go index 57d788df..266abc00 100644 --- a/request.go +++ b/request.go @@ -178,6 +178,7 @@ func requestFromPacket(ctx context.Context, pkt hasPath, baseDir string) *Reques switch p := pkt.(type) { case *sshFxpOpenPacket: request.Flags = p.Pflags + request.Attrs = p.Attrs.([]byte) case *sshFxpSetstatPacket: request.Flags = p.Flags request.Attrs = p.Attrs.([]byte)