Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fix Windows non-blocking I/O #1555

Merged
merged 14 commits into from
Mar 24, 2023
Merged
29 changes: 29 additions & 0 deletions internal/nbconn/nbconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ type NetConn struct {

writeDeadlineLock sync.Mutex
writeDeadline time.Time

// nbOperCnt Tracks how many operations performing simultaneously
nbOperCnt int
// nbOperMu Used to prevent concurrent SetBlockingMode calls
nbOperMu sync.Mutex
}

func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn {
Expand Down Expand Up @@ -157,6 +162,14 @@ func (c *NetConn) Read(b []byte) (n int, err error) {

var readN int
if readNonblocking {
if setSockModeErr := c.SetBlockingMode(false); setSockModeErr != nil {
return n, setSockModeErr
}

defer func() {
_ = c.SetBlockingMode(true)
}()

readN, err = c.nonblockingRead(b[n:])
} else {
readN, err = c.conn.Read(b[n:])
Expand Down Expand Up @@ -281,6 +294,14 @@ func (c *NetConn) flush() error {
var stopChan chan struct{}
var errChan chan error

if err := c.SetBlockingMode(false); err != nil {
return err
}

defer func() {
_ = c.SetBlockingMode(true)
}()

defer func() {
if stopChan != nil {
select {
Expand Down Expand Up @@ -324,6 +345,14 @@ func (c *NetConn) flush() error {
}

func (c *NetConn) BufferReadUntilBlock() error {
if err := c.SetBlockingMode(false); err != nil {
return err
}

defer func() {
_ = c.SetBlockingMode(true)
}()

for {
buf := iobufpool.Get(8 * 1024)
n, err := c.nonblockingRead(*buf)
Expand Down
5 changes: 5 additions & 0 deletions internal/nbconn/nbconn_real_non_block.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,8 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {

return n, nil
}

func (c *NetConn) SetBlockingMode(blocking bool) error {
// Do nothing on UNIX systems
return nil
}
126 changes: 102 additions & 24 deletions internal/nbconn/nbconn_real_non_block_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ package nbconn

import (
"errors"
"fmt"
"golang.org/x/sys/windows"
"io"
"syscall"
"time"
"unsafe"
)

Expand All @@ -21,6 +23,8 @@ var dll = syscall.MustLoadDLL("ws2_32.dll")
// );
var ioctlsocket = dll.MustFindProc("ioctlsocket")

var deadlineExpErr = errors.New("i/o timeout")

type sockMode int

const (
Expand All @@ -39,36 +43,48 @@ func setSockMode(fd uintptr, mode sockMode) error {
return nil
}

func (c *NetConn) isDeadlineSet(dl time.Time) bool {
return !dl.IsZero() && !dl.Equal(NonBlockingDeadline) && !dl.Equal(disableSetDeadlineDeadline)
}

func (c *NetConn) isWriteDeadlineExpired() bool {
c.writeDeadlineLock.Lock()
defer c.writeDeadlineLock.Unlock()

return c.isDeadlineSet(c.writeDeadline) && !time.Now().Before(c.writeDeadline)
}

func (c *NetConn) isReadDeadlineExpired() bool {
c.readDeadlineLock.Lock()
defer c.readDeadlineLock.Unlock()

return c.isDeadlineSet(c.readDeadline) && !time.Now().Before(c.readDeadline)
}

// realNonblockingWrite does a non-blocking write. readFlushLock must already be held.
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
if c.nonblockWriteFunc == nil {
c.nonblockWriteFunc = func(fd uintptr) (done bool) {
// Make sock non-blocking
if err := setSockMode(fd, sockModeNonBlocking); err != nil {
c.nonblockWriteErr = err
return true
}

var written uint32
var buf syscall.WSABuf
buf.Buf = &c.nonblockWriteBuf[0]
buf.Len = uint32(len(c.nonblockWriteBuf))
c.nonblockWriteErr = syscall.WSASend(syscall.Handle(fd), &buf, 1, &written, 0, nil, nil)
c.nonblockWriteN = int(written)

// Make sock blocking again
if err := setSockMode(fd, sockModeBlocking); err != nil {
c.nonblockWriteErr = err
return true
}

return true
}
}
c.nonblockWriteBuf = b
c.nonblockWriteN = 0
c.nonblockWriteErr = nil

if c.isWriteDeadlineExpired() {
c.nonblockWriteErr = deadlineExpErr

return 0, c.nonblockWriteErr
}

err = c.rawConn.Write(c.nonblockWriteFunc)
n = c.nonblockWriteN
c.nonblockWriteBuf = nil // ensure that no reference to b is kept.
Expand All @@ -94,12 +110,6 @@ func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
if c.nonblockReadFunc == nil {
c.nonblockReadFunc = func(fd uintptr) (done bool) {
// Make sock non-blocking
if err := setSockMode(fd, sockModeNonBlocking); err != nil {
c.nonblockReadErr = err
return true
}

var read uint32
var flags uint32
var buf syscall.WSABuf
Expand All @@ -108,19 +118,19 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
c.nonblockReadErr = syscall.WSARecv(syscall.Handle(fd), &buf, 1, &read, &flags, nil, nil)
c.nonblockReadN = int(read)

// Make sock blocking again
if err := setSockMode(fd, sockModeBlocking); err != nil {
c.nonblockReadErr = err
return true
}

return true
}
}
c.nonblockReadBuf = b
c.nonblockReadN = 0
c.nonblockReadErr = nil

if c.isReadDeadlineExpired() {
c.nonblockReadErr = deadlineExpErr

return 0, c.nonblockReadErr
}

err = c.rawConn.Read(c.nonblockReadFunc)
n = c.nonblockReadN
c.nonblockReadBuf = nil // ensure that no reference to b is kept.
Expand All @@ -147,3 +157,71 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {

return n, nil
}

func (c *NetConn) SetBlockingMode(blocking bool) error {
// Fake non-blocking I/O is ignored
if c.rawConn == nil {
return nil
}

// Prevent concurrent SetBlockingMode calls
c.nbOperMu.Lock()
defer c.nbOperMu.Unlock()

// Guard against negative value (which should never happen in practice)
if c.nbOperCnt < 0 {
c.nbOperCnt = 0
}

if blocking {
// Socket is already in blocking mode
if c.nbOperCnt == 0 {
return nil
}

c.nbOperCnt--

// Not ready to exit from non-blocking mode, there is pending non-blocking operations
if c.nbOperCnt > 0 {
return nil
}
} else {
c.nbOperCnt++

// Socket is already in non-blocking mode
if c.nbOperCnt > 1 {
return nil
}
}

mode := sockModeNonBlocking
if blocking {
mode = sockModeBlocking
}

var ctrlErr, err error

ctrlErr = c.rawConn.Control(func(fd uintptr) {
err = setSockMode(fd, mode)
})

if ctrlErr != nil || err != nil {
retErr := ctrlErr
if retErr == nil {
retErr = err
}

// Revert counters inc/dec in case of error
if blocking {
c.nbOperCnt++

return fmt.Errorf("cannot set socket to blocking mode: %w", retErr)
} else {
c.nbOperCnt--

return fmt.Errorf("cannot set socket to non-blocking mode: %w", retErr)
}
}

return nil
}