Skip to content

Commit

Permalink
Implementation read waiter for socks5 UDP and UoT
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 6, 2023
1 parent d0cb357 commit c5c692f
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 1 deletion.
37 changes: 37 additions & 0 deletions common/uot/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/binary"
"io"
"net"
"os"

"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
Expand All @@ -13,11 +14,17 @@ import (
N "github.com/sagernet/sing/common/network"
)

var (
_ N.NetPacketConn = (*Conn)(nil)
_ N.PacketReadWaiter = (*Conn)(nil)
)

type Conn struct {
net.Conn
isConnect bool
destination M.Socksaddr
writer N.VectorisedWriter
newBuffer func() *buf.Buffer
}

func NewConn(conn net.Conn, request Request) *Conn {
Expand Down Expand Up @@ -141,6 +148,36 @@ func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return c.writer.WriteVectorised([]*buf.Buffer{header, buffer})
}

func (c *Conn) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
c.newBuffer = newBuffer
}

func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if c.newBuffer == nil {
return nil, M.Socksaddr{}, os.ErrInvalid
}
if c.isConnect {
destination = c.destination
} else {
destination, err = AddrParser.ReadAddrPort(c.Conn)
if err != nil {
return
}
}
var length uint16
err = binary.Read(c.Conn, binary.BigEndian, &length)
if err != nil {
return
}
buffer = c.newBuffer()
_, err = buffer.ReadFullFrom(c.Conn, int(length))
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, E.Cause(err, "UoT read")
}
return
}

func (c *Conn) NeedAdditionalReadDeadline() bool {
return true
}
Expand Down
3 changes: 2 additions & 1 deletion protocol/socks/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"strings"

"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
Expand Down Expand Up @@ -147,7 +148,7 @@ func (c *Client) DialContext(ctx context.Context, network string, address M.Sock
tcpConn.Close()
return nil, err
}
return NewAssociateConn(udpConn, address, tcpConn), nil
return NewAssociatePacketConn(bufio.NewUnbindPacketConn(udpConn), address, tcpConn), nil
}
return nil, os.ErrInvalid
}
Expand Down
10 changes: 10 additions & 0 deletions protocol/socks/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
Expand All @@ -17,6 +18,8 @@ import (
// | 2 | 1 | 1 | Variable | 2 | Variable |
// +----+------+------+----------+----------+----------+

var ErrInvalidPacket = E.New("socks5: invalid packet")

type AssociatePacketConn struct {
N.NetPacketConn
remoteAddr M.Socksaddr
Expand All @@ -31,6 +34,7 @@ func NewAssociatePacketConn(conn net.PacketConn, remoteAddr M.Socksaddr, underly
}
}

// Deprecated: NewAssociatePacketConn(bufio.NewUnbindPacketConn(conn), remoteAddr, underlying) instead.
func NewAssociateConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn {
return &AssociatePacketConn{
NetPacketConn: bufio.NewUnbindPacketConn(conn),
Expand All @@ -49,6 +53,9 @@ func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err erro
if err != nil {
return
}
if n < 3 {
return 0, nil, ErrInvalidPacket
}
c.remoteAddr = M.SocksaddrFromNet(addr)
reader := bytes.NewReader(p[3:n])
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
Expand Down Expand Up @@ -92,6 +99,9 @@ func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Sock
if err != nil {
return M.Socksaddr{}, err
}
if buffer.Len() < 3 {
return M.Socksaddr{}, ErrInvalidPacket
}
c.remoteAddr = destination
buffer.Advance(3)
destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
Expand Down
48 changes: 48 additions & 0 deletions protocol/socks/packet_wait.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package socks

import (
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

var _ N.PacketReadWaitCreator = (*AssociatePacketConn)(nil)

func (c *AssociatePacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(c.NetPacketConn)
if !isReadWaiter {
return nil, false
}
return &AssociatePacketReadWaiter{c, readWaiter}, true
}

var _ N.PacketReadWaiter = (*AssociatePacketReadWaiter)(nil)

type AssociatePacketReadWaiter struct {
conn *AssociatePacketConn
readWaiter N.PacketReadWaiter
}

func (w *AssociatePacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
w.readWaiter.InitializeReadWaiter(newBuffer)
}

func (w *AssociatePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
buffer, destination, err = w.readWaiter.WaitReadPacket()
if err != nil {
return
}
if buffer.Len() < 3 {
buffer.Release()
return nil, M.Socksaddr{}, ErrInvalidPacket
}
w.conn.remoteAddr = destination
buffer.Advance(3)
destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
buffer.Release()
return nil, M.Socksaddr{}, err
}
return
}

0 comments on commit c5c692f

Please sign in to comment.