Skip to content

Commit

Permalink
Add reserve support for buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 7, 2023
1 parent ceb148c commit 187a68e
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 83 deletions.
80 changes: 43 additions & 37 deletions common/buf/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"crypto/rand"
"io"
"net"
"strconv"
"sync/atomic"

"github.com/sagernet/sing/common"
Expand All @@ -17,21 +16,23 @@ type Buffer struct {
data []byte
start int
end int
length int
refs atomic.Int32
managed bool
closed bool
}

func New() *Buffer {
return &Buffer{
data: Get(BufferSize),
length: BufferSize,
managed: true,
}
}

func NewPacket() *Buffer {
return &Buffer{
data: Get(UDPBufferSize),
length: UDPBufferSize,
managed: true,
}
}
Expand All @@ -41,40 +42,29 @@ func NewSize(size int) *Buffer {
return &Buffer{}
} else if size > 65535 {
return &Buffer{
data: make([]byte, size),
data: make([]byte, size),
length: size,
}
}
return &Buffer{
data: Get(size),
length: size,
managed: true,
}
}

// Deprecated: use New instead.
func StackNew() *Buffer {
return New()
}

// Deprecated: use NewPacket instead.
func StackNewPacket() *Buffer {
return NewPacket()
}

// Deprecated: use NewSize instead.
func StackNewSize(size int) *Buffer {
return NewSize(size)
}

func As(data []byte) *Buffer {
return &Buffer{
data: data,
end: len(data),
data: data,
end: len(data),
length: len(data),
}
}

func With(data []byte) *Buffer {
return &Buffer{
data: data,
data: data,
length: len(data),
}
}

Expand All @@ -88,8 +78,8 @@ func (b *Buffer) SetByte(index int, value byte) {

func (b *Buffer) Extend(n int) []byte {
end := b.end + n
if end > cap(b.data) {
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",end " + strconv.Itoa(b.end) + ", need " + strconv.Itoa(n))
if end > b.length {
panic(F.ToString("buffer overflow: length ", b.length, ",end ", b.end, ", need ", n))
}
ext := b.data[b.end:end]
b.end = end
Expand All @@ -111,14 +101,14 @@ func (b *Buffer) Write(data []byte) (n int, err error) {
if b.IsFull() {
return 0, io.ErrShortBuffer
}
n = copy(b.data[b.end:], data)
n = copy(b.data[b.end:b.length], data)
b.end += n
return
}

func (b *Buffer) ExtendHeader(n int) []byte {
if b.start < n {
panic("buffer overflow: cap " + strconv.Itoa(cap(b.data)) + ",start " + strconv.Itoa(b.start) + ", need " + strconv.Itoa(n))
panic(F.ToString("buffer overflow: length ", b.length, ",start ", b.start, ", need ", n))
}
b.start -= n
return b.data[b.start : b.start+n]
Expand Down Expand Up @@ -171,7 +161,7 @@ func (b *Buffer) ReadAtLeastFrom(r io.Reader, min int) (int64, error) {
}

func (b *Buffer) ReadFullFrom(r io.Reader, size int) (n int, err error) {
if b.end+size > b.Cap() {
if b.end+size > b.length {
return 0, io.ErrShortBuffer
}
n, err = io.ReadFull(r, b.data[b.end:b.end+size])
Expand Down Expand Up @@ -208,7 +198,7 @@ func (b *Buffer) WriteString(s string) (n int, err error) {
if b.IsFull() {
return 0, io.ErrShortBuffer
}
n = copy(b.data[b.end:], s)
n = copy(b.data[b.end:b.length], s)
b.end += n
return
}
Expand All @@ -223,7 +213,7 @@ func (b *Buffer) WriteZero() error {
}

func (b *Buffer) WriteZeroN(n int) error {
if b.end+n > b.Cap() {
if b.end+n > b.length {
return io.ErrShortBuffer
}
for i := b.end; i <= b.end+n; i++ {
Expand Down Expand Up @@ -272,9 +262,24 @@ func (b *Buffer) Resize(start, end int) {
b.end = b.start + end
}

func (b *Buffer) Reserve(n int) {
if n > b.length {
panic(F.ToString("buffer overflow: length ", b.length, ", need ", n))
}
b.length -= n
}

func (b *Buffer) OverLength(n int) {
if b.length+n > len(b.data) {
panic(F.ToString("buffer overflow: length ", len(b.data), ", need ", b.length+n))
}
b.length += n
}

func (b *Buffer) Reset() {
b.start = 0
b.end = 0
b.length = len(b.data)
}

// Deprecated: use Reset instead.
Expand All @@ -291,19 +296,19 @@ func (b *Buffer) DecRef() {
}

func (b *Buffer) Release() {
if b == nil || b.closed || !b.managed {
if b == nil || !b.managed {
return
}
if b.refs.Load() > 0 {
return
}
common.Must(Put(b.data))
*b = Buffer{closed: true}
*b = Buffer{}
}

func (b *Buffer) Leak() {
if debug.Enabled {
if b == nil || b.closed || !b.managed {
if b == nil || !b.managed {
return
}
refs := b.refs.Load()
Expand All @@ -319,7 +324,7 @@ func (b *Buffer) Leak() {

func (b *Buffer) Cut(start int, end int) *Buffer {
b.start += start
b.end = len(b.data) - end
b.end = b.length - end
return &Buffer{
data: b.data[b.start:b.end],
}
Expand All @@ -334,15 +339,15 @@ func (b *Buffer) Len() int {
}

func (b *Buffer) Cap() int {
return len(b.data)
return b.length
}

func (b *Buffer) Bytes() []byte {
return b.data[b.start:b.end]
}

func (b *Buffer) Slice() []byte {
return b.data
return b.data[:b.length]
}

func (b *Buffer) From(n int) []byte {
Expand All @@ -362,25 +367,26 @@ func (b *Buffer) Index(start int) []byte {
}

func (b *Buffer) FreeLen() int {
return b.Cap() - b.end
return b.length - b.end
}

func (b *Buffer) FreeBytes() []byte {
return b.data[b.end:b.Cap()]
return b.data[b.end:b.length]
}

func (b *Buffer) IsEmpty() bool {
return b.end-b.start == 0
}

func (b *Buffer) IsFull() bool {
return b.end == b.Cap()
return b.end == b.length
}

func (b *Buffer) ToOwned() *Buffer {
n := NewSize(len(b.data))
copy(n.data[b.start:b.end], b.data[b.start:b.end])
n.start = b.start
n.end = b.end
n.length = b.length
return n
}
33 changes: 15 additions & 18 deletions common/bufio/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,20 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
defer buffer.DecRef()
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
var notFirstTime bool
for {
readBuffer.Resize(frontHeadroom, 0)
err = source.ReadBuffer(readBuffer)
err = source.ReadBuffer(buffer)
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
return
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
dataLen := buffer.Len()
buffer.OverLength(rearHeadroom)
err = destination.WriteBuffer(buffer)
if err != nil {
if !notFirstTime {
Expand Down Expand Up @@ -126,10 +125,9 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
var notFirstTime bool
for {
buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
err = source.ReadBuffer(readBuffer)
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
err = source.ReadBuffer(buffer)
if err != nil {
buffer.Release()
if errors.Is(err, io.EOF) {
Expand All @@ -138,8 +136,8 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
}
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
dataLen := buffer.Len()
buffer.OverLength(rearHeadroom)
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Leak()
Expand Down Expand Up @@ -263,16 +261,15 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri
var destination M.Socksaddr
for {
buffer := buf.NewSize(bufferSize)
readBufferRaw := buffer.Slice()
readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
readBuffer.Resize(frontHeadroom, 0)
destination, err = source.ReadPacket(readBuffer)
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
destination, err = source.ReadPacket(buffer)
if err != nil {
buffer.Release()
return
}
dataLen := readBuffer.Len()
buffer.Resize(readBuffer.Start(), dataLen)
dataLen := buffer.Len()
buffer.OverLength(rearHeadroom)
err = destinationConn.WritePacket(buffer, destination)
if err != nil {
buffer.Leak()
Expand Down
14 changes: 8 additions & 6 deletions common/bufio/copy_direct_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
buffer, readBuffer := w.options.NewBuffer()
buffer := w.options.NewBuffer()
var readN int
readN, w.readErr = syscall.Read(int(fd), readBuffer.FreeBytes())
readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
if readN > 0 {
buffer.Resize(readBuffer.Start(), readN)
buffer.Truncate(readN)
} else {
buffer.Release()
buffer = nil
Expand All @@ -119,6 +119,7 @@ func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (nee
if readN == 0 {
w.readErr = io.EOF
}
w.options.PostReturn(buffer)
w.buffer = buffer
return true
}
Expand Down Expand Up @@ -168,12 +169,12 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
buffer, readBuffer := w.options.NewPacketBuffer()
buffer := w.options.NewPacketBuffer()
var readN int
var from syscall.Sockaddr
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), readBuffer.FreeBytes(), nil, 0)
readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
if readN > 0 {
buffer.Resize(readBuffer.Start(), readN)
buffer.Truncate(readN)
} else {
buffer.Release()
buffer = nil
Expand All @@ -189,6 +190,7 @@ func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
}
}
w.options.PostReturn(buffer)
w.buffer = buffer
return true
}
Expand Down
Loading

0 comments on commit 187a68e

Please sign in to comment.