Skip to content

Commit

Permalink
Merge ThreadSafeReader into ReadWaiter interface
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 6, 2023
1 parent 86c131f commit 0b5ff61
Show file tree
Hide file tree
Showing 11 changed files with 217 additions and 159 deletions.
26 changes: 22 additions & 4 deletions common/buf/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@ import (
"sync/atomic"

"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
)

type Buffer struct {
data []byte
start int
end int
refs int32
refs atomic.Int32
managed bool
closed bool
}
Expand Down Expand Up @@ -281,24 +283,40 @@ func (b *Buffer) FullReset() {
}

func (b *Buffer) IncRef() {
atomic.AddInt32(&b.refs, 1)
b.refs.Add(1)
}

func (b *Buffer) DecRef() {
atomic.AddInt32(&b.refs, -1)
b.refs.Add(-1)
}

func (b *Buffer) Release() {
if b == nil || b.closed || !b.managed {
return
}
if atomic.LoadInt32(&b.refs) > 0 {
if b.refs.Load() > 0 {
return
}
common.Must(Put(b.data))
*b = Buffer{closed: true}
}

func (b *Buffer) Leak() {
if debug.Enabled {
if b == nil || b.closed || !b.managed {
return
}
refs := b.refs.Load()
if refs == 0 {
panic("leaking buffer")
} else {
panic(F.ToString("leaking buffer with ", refs, " references"))
}
} else {
b.Release()
}
}

func (b *Buffer) Cut(start int, end int) *Buffer {
b.start += start
b.end = len(b.data) - end
Expand Down
8 changes: 4 additions & 4 deletions common/bufio/bind_wait.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ type BindPacketReadWaiter struct {
readWaiter N.PacketReadWaiter
}

func (w *BindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readWaiter.InitializeReadWaiter(newBuffer)
func (w *BindPacketReadWaiter) InitializeReadWaiter(options *N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}

func (w *BindPacketReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
Expand All @@ -28,8 +28,8 @@ type UnbindPacketReadWaiter struct {
addr M.Socksaddr
}

func (w *UnbindPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readWaiter.InitializeReadWaiter(newBuffer)
func (w *UnbindPacketReadWaiter) InitializeReadWaiter(options *N.ReadWaitOptions) (needCopy bool) {
return w.readWaiter.InitializeReadWaiter(options)
}

func (w *UnbindPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
Expand Down
48 changes: 36 additions & 12 deletions common/bufio/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,31 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
}

func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
//nolint:staticcheck
//goland:noinspection GoDeprecation
safeSrc := N.IsSafeReader(source)
headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination)
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
if safeSrc != nil {
if headroom == 0 {
if frontHeadroom == 0 && rearHeadroom == 0 {
//nolint:staticcheck
//goland:noinspection GoDeprecation
return CopyExtendedWithSrcBuffer(originSource, destination, safeSrc, readCounters, writeCounters)
}
}
readWaiter, isReadWaiter := CreateReadWaiter(source)
if isReadWaiter {
var handled bool
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
if handled {
return
needCopy := readWaiter.InitializeReadWaiter(&N.ReadWaitOptions{
FrontHeadroom: frontHeadroom,
RearHeadroom: rearHeadroom,
MTU: N.CalculateMTU(source, destination),
})
if !needCopy || common.LowMemory {
var handled bool
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
if handled {
return
}
}
}
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
Expand Down Expand Up @@ -113,6 +125,7 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
}
}

// Deprecated: Use ReadWaiter interface instead.
func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
var notFirstTime bool
for {
Expand Down Expand Up @@ -256,13 +269,16 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
return
}
}
//nolint:staticcheck
//goland:noinspection GoDeprecation
safeSrc := N.IsSafePacketReader(source)
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
headroom := frontHeadroom + rearHeadroom
if safeSrc != nil {
if headroom == 0 {
if frontHeadroom == 0 && rearHeadroom == 0 {
var copyN int64
//nolint:staticcheck
//goland:noinspection GoDeprecation
copyN, err = CopyPacketWithSrcBuffer(originSource, destinationConn, safeSrc, readCounters, writeCounters, n > 0)
n += copyN
return
Expand All @@ -274,17 +290,25 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
)
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
if isReadWaiter {
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
if handled {
n += copeN
return
needCopy := readWaiter.InitializeReadWaiter(&N.ReadWaitOptions{
FrontHeadroom: frontHeadroom,
RearHeadroom: rearHeadroom,
MTU: N.CalculateMTU(source, destinationConn),
})
if !needCopy || common.LowMemory {
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
if handled {
n += copeN
return
}
}
}
copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0)
n += copeN
return
}

// Deprecated: Use PacketReadWaiter interface instead.
func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
var buffer *buf.Buffer
var destination M.Socksaddr
Expand Down
147 changes: 53 additions & 94 deletions common/bufio/copy_direct_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,41 +17,23 @@ import (

func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
handled = true
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
bufferSize := N.CalculateMTU(source, destination)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
var (
buffer *buf.Buffer
readBuffer *buf.Buffer
notFirstTime bool
)
source.InitializeReadWaiter(func() *buf.Buffer {
buffer = buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
return readBuffer
})
defer source.InitializeReadWaiter(nil)
for {
_, err = source.WaitReadBuffer()
buffer, err = source.WaitReadBuffer()
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
return
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
dataLen := buffer.Len()
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Release()
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
Expand All @@ -70,37 +52,19 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour

func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
handled = true
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
bufferSize := N.CalculateMTU(source, destinationConn)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.UDPBufferSize
}
var (
buffer *buf.Buffer
readBuffer *buf.Buffer
destination M.Socksaddr
)
source.InitializeReadWaiter(func() *buf.Buffer {
buffer = buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
return readBuffer
})
defer source.InitializeReadWaiter(nil)
for {
_, destination, err = source.WaitReadPacket()
buffer, destination, err = source.WaitReadPacket()
if err != nil {
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
dataLen := buffer.Len()
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Release()
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
Expand All @@ -124,6 +88,7 @@ type syscallReadWaiter struct {
readErr error
readFunc func(fd uintptr) (done bool)
buffer *buf.Buffer
options *N.ReadWaitOptions
}

func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
Expand All @@ -136,31 +101,28 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
return nil, false
}

func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readErr = nil
if newBuffer == nil {
w.readFunc = nil
} else {
w.readFunc = func(fd uintptr) (done bool) {
buffer := newBuffer()
var readN int
readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
if readN > 0 {
buffer.Truncate(readN)
} else {
buffer.Release()
buffer = nil
}
if w.readErr == syscall.EAGAIN {
return false
}
if readN == 0 {
w.readErr = io.EOF
}
w.buffer = buffer
return true
func (w *syscallReadWaiter) InitializeReadWaiter(options *N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
buffer, readBuffer := w.options.NewBuffer()
var readN int
readN, w.readErr = syscall.Read(int(fd), readBuffer.FreeBytes())
if readN > 0 {
buffer.Resize(readBuffer.Start(), readN)
} else {
buffer.Release()
buffer = nil
}
if w.readErr == syscall.EAGAIN {
return false
}
if readN == 0 {
w.readErr = io.EOF
}
w.buffer = buffer
return true
}
return false
}

func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
Expand Down Expand Up @@ -190,6 +152,7 @@ type syscallPacketReadWaiter struct {
readFrom M.Socksaddr
readFunc func(fd uintptr) (done bool)
buffer *buf.Buffer
options *N.ReadWaitOptions
}

func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
Expand All @@ -202,38 +165,34 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
return nil, false
}

func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readErr = nil
w.readFrom = M.Socksaddr{}
if newBuffer == nil {
w.readFunc = nil
} else {
w.readFunc = func(fd uintptr) (done bool) {
buffer := newBuffer()
var readN int
var from syscall.Sockaddr
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
if readN > 0 {
buffer.Truncate(readN)
} else {
buffer.Release()
buffer = nil
}
if w.readErr == syscall.EAGAIN {
return false
}
if from != nil {
switch fromAddr := from.(type) {
case *syscall.SockaddrInet4:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
case *syscall.SockaddrInet6:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
}
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options *N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
buffer, readBuffer := w.options.NewBuffer()
var readN int
var from syscall.Sockaddr
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), readBuffer.FreeBytes(), nil, 0)
if readN > 0 {
buffer.Resize(readBuffer.Start(), readN)
} else {
buffer.Release()
buffer = nil
}
if w.readErr == syscall.EAGAIN {
return false
}
if from != nil {
switch fromAddr := from.(type) {
case *syscall.SockaddrInet4:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
case *syscall.SockaddrInet6:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
}
w.buffer = buffer
return true
}
w.buffer = buffer
return true
}
return false
}

func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
Expand Down
Loading

0 comments on commit 0b5ff61

Please sign in to comment.