Skip to content

Commit

Permalink
Merge ThreadSafeReader into ReadWaiter interface
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 6, 2023
1 parent 86c131f commit 25b4d58
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 35 deletions.
8 changes: 4 additions & 4 deletions common/bufio/bind_wait.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ type BindPacketReadWaiter struct {
readWaiter N.PacketReadWaiter
}

func (w *BindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readWaiter.InitializeReadWaiter(newBuffer)
func (w *BindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer, needHeadroom bool) bool {
return w.readWaiter.InitializeReadWaiter(newBuffer, needHeadroom)
}

func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
Expand All @@ -28,8 +28,8 @@ type UnbindPacketReadWaiter struct {
addr M.Socksaddr
}

func (w *UnbindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readWaiter.InitializeReadWaiter(newBuffer)
func (w *UnbindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer, needHeadroom bool) bool {
return w.readWaiter.InitializeReadWaiter(newBuffer, needHeadroom)
}

func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
Expand Down
14 changes: 12 additions & 2 deletions common/bufio/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,19 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
}

func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
//nolint:staticcheck
//goland:noinspection GoDeprecation
safeSrc := N.IsSafeReader(source)
headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination)
if safeSrc != nil {
if headroom == 0 {
//nolint:staticcheck
//goland:noinspection GoDeprecation
return CopyExtendedWithSrcBuffer(originSource, destination, safeSrc, readCounters, writeCounters)
}
}
readWaiter, isReadWaiter := CreateReadWaiter(source)
if isReadWaiter {
if isReadWaiter && (readWaiter.InitializeReadWaiter(nil, headroom > 0) || headroom == 0 || common.LowMemory) {
var handled bool
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
if handled {
Expand Down Expand Up @@ -113,6 +117,7 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
}
}

// Deprecated: Use ReadWaiter interface instead.
func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
var notFirstTime bool
for {
Expand Down Expand Up @@ -256,13 +261,17 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
return
}
}
//nolint:staticcheck
//goland:noinspection GoDeprecation
safeSrc := N.IsSafePacketReader(source)
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
headroom := frontHeadroom + rearHeadroom
if safeSrc != nil {
if headroom == 0 {
var copyN int64
//nolint:staticcheck
//goland:noinspection GoDeprecation
copyN, err = CopyPacketWithSrcBuffer(originSource, destinationConn, safeSrc, readCounters, writeCounters, n > 0)
n += copyN
return
Expand All @@ -273,7 +282,7 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
copeN int64
)
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
if isReadWaiter {
if isReadWaiter && (readWaiter.InitializeReadWaiter(nil, headroom > 0) || headroom == 0 || common.LowMemory) {
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
if handled {
n += copeN
Expand All @@ -285,6 +294,7 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
return
}

// Deprecated: Use PacketReadWaiter interface instead.
func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
var buffer *buf.Buffer
var destination M.Socksaddr
Expand Down
78 changes: 55 additions & 23 deletions common/bufio/copy_direct_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour
handled = true
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
needHeadroom := frontHeadroom > 0 || rearHeadroom > 0
bufferSize := N.CalculateMTU(source, destination)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
Expand All @@ -27,31 +28,45 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour
}
var (
buffer *buf.Buffer
readBuffer *buf.Buffer
resultBuffer *buf.Buffer
notFirstTime bool
)
source.InitializeReadWaiter(func() *buf.Buffer {
externalBuffer := source.InitializeReadWaiter(func() *buf.Buffer {
if buffer != nil {
buffer.Release()
}
buffer = buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
return readBuffer
})
defer source.InitializeReadWaiter(nil)
}, needHeadroom)
defer source.InitializeReadWaiter(nil, false)
for {
_, err = source.WaitReadBuffer()
resultBuffer, err = source.WaitReadBuffer()
if err != nil {
if buffer != nil {
buffer.Release()
}
if errors.Is(err, io.EOF) {
err = nil
return
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = destination.WriteBuffer(buffer)
dataLen := resultBuffer.Len()
if externalBuffer {
err = destination.WriteBuffer(resultBuffer)
} else {
buffer.Resize(resultBuffer.Start(), dataLen)
err = destination.WriteBuffer(buffer)
}
if err != nil {
buffer.Release()
if externalBuffer {
resultBuffer.Release()
} else {
buffer.Release()
}
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
Expand All @@ -72,35 +87,50 @@ func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.Packe
handled = true
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
needHeadroom := frontHeadroom > 0 || rearHeadroom > 0
bufferSize := N.CalculateMTU(source, destinationConn)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.UDPBufferSize
}
var (
buffer *buf.Buffer
readBuffer *buf.Buffer
destination M.Socksaddr
buffer *buf.Buffer
resultBuffer *buf.Buffer
destination M.Socksaddr
)
source.InitializeReadWaiter(func() *buf.Buffer {
externalBuffer := source.InitializeReadWaiter(func() *buf.Buffer {
if buffer != nil {
buffer.Release()
}
buffer = buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
return readBuffer
})
defer source.InitializeReadWaiter(nil)
}, needHeadroom)
defer source.InitializeReadWaiter(nil, false)
for {
_, destination, err = source.WaitReadPacket()
if err != nil {
if buffer != nil {
buffer.Release()
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
err = destinationConn.WritePacket(buffer, destination)
dataLen := resultBuffer.Len()
if externalBuffer {
err = destinationConn.WritePacket(resultBuffer, destination)
} else {
buffer.Resize(resultBuffer.Start(), dataLen)
err = destinationConn.WritePacket(buffer, destination)
}
if err != nil {
buffer.Release()
if externalBuffer {
resultBuffer.Release()
} else {
buffer.Release()
}
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
Expand Down Expand Up @@ -136,7 +166,7 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
return nil, false
}

func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer, _ bool) bool {
w.readErr = nil
if newBuffer == nil {
w.readFunc = nil
Expand All @@ -161,6 +191,7 @@ func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
return true
}
}
return true
}

func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
Expand Down Expand Up @@ -202,7 +233,7 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
return nil, false
}

func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer, _ bool) bool {
w.readErr = nil
w.readFrom = M.Socksaddr{}
if newBuffer == nil {
Expand Down Expand Up @@ -234,6 +265,7 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buf
return true
}
}
return true
}

func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
Expand Down
8 changes: 6 additions & 2 deletions common/network/direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ import (
M "github.com/sagernet/sing/common/metadata"
)

type ReadWaitable interface {
InitializeReadWaiter(newBuffer func() *buf.Buffer, needHeadroom bool) (externalBuffer bool)
}

type ReadWaiter interface {
InitializeReadWaiter(newBuffer func() *buf.Buffer)
ReadWaitable
WaitReadBuffer() (buffer *buf.Buffer, err error)
}

Expand All @@ -15,7 +19,7 @@ type ReadWaitCreator interface {
}

type PacketReadWaiter interface {
InitializeReadWaiter(newBuffer func() *buf.Buffer)
ReadWaitable
WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error)
}

Expand Down
6 changes: 6 additions & 0 deletions common/network/thread.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@ type ThreadUnsafeWriter interface {
WriteIsThreadUnsafe()
}

// Deprecated: Use ReadWaiter interface instead.
type ThreadSafeReader interface {
// Deprecated: Use ReadWaiter interface instead.
ReadBufferThreadSafe() (buffer *buf.Buffer, err error)
}

// Deprecated: Use ReadWaiter interface instead.
type ThreadSafePacketReader interface {
// Deprecated: Use ReadWaiter interface instead.
ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error)
}

Expand All @@ -23,6 +27,7 @@ func IsUnsafeWriter(writer any) bool {
return isUnsafe
}

// Deprecated: Use ReadWaiter interface instead.
func IsSafeReader(reader any) ThreadSafeReader {
if safeReader, isSafe := reader.(ThreadSafeReader); isSafe {
return safeReader
Expand All @@ -39,6 +44,7 @@ func IsSafeReader(reader any) ThreadSafeReader {
return nil
}

// Deprecated: Use ReadWaiter interface instead.
func IsSafePacketReader(reader any) ThreadSafePacketReader {
if safeReader, isSafe := reader.(ThreadSafePacketReader); isSafe {
return safeReader
Expand Down
3 changes: 2 additions & 1 deletion common/pipe/pipe_wait.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ import (

var _ N.ReadWaiter = (*pipe)(nil)

func (p *pipe) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
func (p *pipe) InitializeReadWaiter(newBuffer func() *buf.Buffer, _ bool) bool {
p.newBuffer = newBuffer
return true
}

func (p *pipe) WaitReadBuffer() (buffer *buf.Buffer, err error) {
Expand Down
3 changes: 2 additions & 1 deletion common/uot/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,9 @@ func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return c.writer.WriteVectorised([]*buf.Buffer{header, buffer})
}

func (c *Conn) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
func (c *Conn) InitializeReadWaiter(newBuffer func() *buf.Buffer, _ bool) bool {
c.newBuffer = newBuffer
return true
}

func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
Expand Down
5 changes: 3 additions & 2 deletions protocol/socks/packet_wait.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ type AssociatePacketReadWaiter struct {
readWaiter N.PacketReadWaiter
}

func (w *AssociatePacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readWaiter.InitializeReadWaiter(newBuffer)
func (w *AssociatePacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer, needHeadroom bool) bool {
w.readWaiter.InitializeReadWaiter(newBuffer, needHeadroom)
return true
}

func (w *AssociatePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
Expand Down

0 comments on commit 25b4d58

Please sign in to comment.