Skip to content

Commit

Permalink
Enable read wait copy for windows
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 7, 2023
1 parent 0313eb6 commit ed57e5f
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 79 deletions.
2 changes: 1 addition & 1 deletion common/bufio/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = CopyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
}
Expand Down
72 changes: 71 additions & 1 deletion common/bufio/copy_direct.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package bufio

import (
"errors"
"io"
"syscall"

"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
rawSource, err := source.SyscallConn()
if err != nil {
return
Expand All @@ -18,3 +22,69 @@ func CopyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.
handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters)
return
}

func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
handled = true
var (
buffer *buf.Buffer
notFirstTime bool
)
for {
buffer, err = source.WaitReadBuffer()
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
return
}
return
}
dataLen := buffer.Len()
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}

func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
handled = true
var (
buffer *buf.Buffer
destination M.Socksaddr
)
for {
buffer, destination, err = source.WaitReadPacket()
if err != nil {
return
}
dataLen := buffer.Len()
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}
67 changes: 0 additions & 67 deletions common/bufio/copy_direct_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package bufio

import (
"errors"
"io"
"net/netip"
"os"
Expand All @@ -15,72 +14,6 @@ import (
N "github.com/sagernet/sing/common/network"
)

func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
handled = true
var (
buffer *buf.Buffer
notFirstTime bool
)
for {
buffer, err = source.WaitReadBuffer()
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
return
}
return
}
dataLen := buffer.Len()
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}

func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
handled = true
var (
buffer *buf.Buffer
destination M.Socksaddr
)
for {
buffer, destination, err = source.WaitReadPacket()
if err != nil {
return
}
dataLen := buffer.Len()
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
notFirstTime = true
}
}

var _ N.ReadWaiter = (*syscallReadWaiter)(nil)

type syscallReadWaiter struct {
Expand Down
10 changes: 0 additions & 10 deletions common/bufio/copy_direct_windows.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,9 @@
package bufio

import (
"io"

N "github.com/sagernet/sing/common/network"
)

func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
return
}

func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
return
}

func createSyscallReadWaiter(reader any) (N.ReadWaiter, bool) {
return nil, false
}
Expand Down

0 comments on commit ed57e5f

Please sign in to comment.