From b93ea6f4d944a0c09d46656a2290117767546d0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 15 Dec 2023 20:43:44 +0800 Subject: [PATCH] Add TX checksum offload for Linux --- stack_mixed.go | 30 ++++++------- stack_system.go | 100 ++++++++++++++++++++++++++++---------------- tun.go | 6 ++- tun_darwin.go | 6 --- tun_linux.go | 49 ++++++++++++---------- tun_linux_flags.go | 84 +++++++++++++++++++++++++++++++++++++ tun_linux_gvisor.go | 2 + tun_windows.go | 4 -- 8 files changed, 197 insertions(+), 84 deletions(-) create mode 100644 tun_linux_flags.go diff --git a/stack_mixed.go b/stack_mixed.go index 7c1a8f9..811e0fd 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -91,17 +91,18 @@ func (m *Mixed) tunLoop() { m.wintunLoop(winTun) return } - if batchTUN, isBatchTUN := m.tun.(BatchTUN); isBatchTUN { - batchSize := batchTUN.BatchSize() + if linuxTUN, isLinuxTUN := m.tun.(LinuxTUN); isLinuxTUN { + m.frontHeadroom = linuxTUN.FrontHeadroom() + m.txChecksumOffload = linuxTUN.TXChecksumOffload() + batchSize := linuxTUN.BatchSize() if batchSize > 1 { - m.batchLoop(batchTUN, batchSize) + m.batchLoop(linuxTUN, batchSize) return } } - frontHeadroom := m.tun.FrontHeadroom() - packetBuffer := make([]byte, m.mtu+frontHeadroom+PacketOffset) + packetBuffer := make([]byte, m.mtu+PacketOffset) for { - n, err := m.tun.Read(packetBuffer[frontHeadroom:]) + n, err := m.tun.Read(packetBuffer) if err != nil { if E.IsClosed(err) { return @@ -111,8 +112,8 @@ func (m *Mixed) tunLoop() { if n < clashtcpip.IPv4PacketMinLength { continue } - rawPacket := packetBuffer[:frontHeadroom+n] - packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n] + rawPacket := packetBuffer[:n] + packet := packetBuffer[PacketOffset:n] if m.processPacket(packet) { _, err = m.tun.Write(rawPacket) if err != nil { @@ -142,16 +143,15 @@ func (m *Mixed) wintunLoop(winTun WinTun) { } } -func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) { - frontHeadroom := m.tun.FrontHeadroom() +func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) { packetBuffers := make([][]byte, batchSize) writeBuffers := make([][]byte, batchSize) packetSizes := make([]int, batchSize) for i := range packetBuffers { - packetBuffers[i] = make([]byte, m.mtu+frontHeadroom) + packetBuffers[i] = make([]byte, m.mtu+m.frontHeadroom) } for { - n, err := linuxTUN.BatchRead(packetBuffers, frontHeadroom, packetSizes) + n, err := linuxTUN.BatchRead(packetBuffers, m.frontHeadroom, packetSizes) if err != nil { if E.IsClosed(err) { return @@ -167,13 +167,13 @@ func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) { continue } packetBuffer := packetBuffers[i] - packet := packetBuffer[frontHeadroom : frontHeadroom+packetSize] + packet := packetBuffer[m.frontHeadroom : m.frontHeadroom+packetSize] if m.processPacket(packet) { - writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize]) + writeBuffers = append(writeBuffers, packetBuffer[:m.frontHeadroom+packetSize]) } } if len(writeBuffers) > 0 { - err = linuxTUN.BatchWrite(writeBuffers, frontHeadroom) + err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom) if err != nil { m.logger.Trace(E.Cause(err, "batch write packet")) } diff --git a/stack_system.go b/stack_system.go index 8b687fa..b340525 100644 --- a/stack_system.go +++ b/stack_system.go @@ -41,6 +41,8 @@ type System struct { udpNat *udpnat.Service[netip.AddrPort] bindInterface bool interfaceFinder control.InterfaceFinder + frontHeadroom int + txChecksumOffload bool } type Session struct { @@ -144,17 +146,18 @@ func (s *System) tunLoop() { s.wintunLoop(winTun) return } - if batchTUN, isBatchTUN := s.tun.(BatchTUN); isBatchTUN { - batchSize := batchTUN.BatchSize() + if linuxTUN, isLinuxTUN := s.tun.(LinuxTUN); isLinuxTUN { + s.frontHeadroom = linuxTUN.FrontHeadroom() + s.txChecksumOffload = linuxTUN.TXChecksumOffload() + batchSize := linuxTUN.BatchSize() if batchSize > 1 { - s.batchLoop(batchTUN, batchSize) + s.batchLoop(linuxTUN, batchSize) return } } - frontHeadroom := s.tun.FrontHeadroom() - packetBuffer := make([]byte, s.mtu+frontHeadroom+PacketOffset) + packetBuffer := make([]byte, s.mtu+PacketOffset) for { - n, err := s.tun.Read(packetBuffer[frontHeadroom:]) + n, err := s.tun.Read(packetBuffer) if err != nil { if E.IsClosed(err) { return @@ -164,8 +167,8 @@ func (s *System) tunLoop() { if n < clashtcpip.IPv4PacketMinLength { continue } - rawPacket := packetBuffer[:frontHeadroom+n] - packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n] + rawPacket := packetBuffer[:n] + packet := packetBuffer[PacketOffset:n] if s.processPacket(packet) { _, err = s.tun.Write(rawPacket) if err != nil { @@ -195,16 +198,15 @@ func (s *System) wintunLoop(winTun WinTun) { } } -func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) { - frontHeadroom := s.tun.FrontHeadroom() +func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) { packetBuffers := make([][]byte, batchSize) writeBuffers := make([][]byte, batchSize) packetSizes := make([]int, batchSize) for i := range packetBuffers { - packetBuffers[i] = make([]byte, s.mtu+frontHeadroom) + packetBuffers[i] = make([]byte, s.mtu+s.frontHeadroom) } for { - n, err := linuxTUN.BatchRead(packetBuffers, frontHeadroom, packetSizes) + n, err := linuxTUN.BatchRead(packetBuffers, s.frontHeadroom, packetSizes) if err != nil { if E.IsClosed(err) { return @@ -220,13 +222,13 @@ func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) { continue } packetBuffer := packetBuffers[i] - packet := packetBuffer[frontHeadroom : frontHeadroom+packetSize] + packet := packetBuffer[s.frontHeadroom : s.frontHeadroom+packetSize] if s.processPacket(packet) { - writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize]) + writeBuffers = append(writeBuffers, packetBuffer[:s.frontHeadroom+packetSize]) } } if len(writeBuffers) > 0 { - err = linuxTUN.BatchWrite(writeBuffers, frontHeadroom) + err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom) if err != nil { s.logger.Trace(E.Cause(err, "batch write packet")) } @@ -352,8 +354,10 @@ func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip. packet.SetDestinationIP(s.inet4ServerAddress) header.SetDestinationPort(s.tcpPort) } - header.ResetChecksum(packet.PseudoSum()) - packet.ResetChecksum() + if !s.txChecksumOffload { + header.ResetChecksum(packet.PseudoSum()) + packet.ResetChecksum() + } return nil } @@ -378,8 +382,10 @@ func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip. packet.SetDestinationIP(s.inet6ServerAddress) header.SetDestinationPort(s.tcpPort6) } - header.ResetChecksum(packet.PseudoSum()) - packet.ResetChecksum() + if !s.txChecksumOffload { + header.ResetChecksum(packet.PseudoSum()) + packet.ResetChecksum() + } return nil } @@ -410,7 +416,13 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip. headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize headerCopy := make([]byte, headerLen) copy(headerCopy, packet[:headerLen]) - return &systemUDPPacketWriter4{s.tun, s.tun.FrontHeadroom() + PacketOffset, headerCopy, source} + return &systemUDPPacketWriter4{ + s.tun, + s.frontHeadroom + PacketOffset, + headerCopy, + source, + s.txChecksumOffload, + } }) return nil } @@ -436,7 +448,13 @@ func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip. headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize headerCopy := make([]byte, headerLen) copy(headerCopy, packet[:headerLen]) - return &systemUDPPacketWriter6{s.tun, s.tun.FrontHeadroom() + PacketOffset, headerCopy, source} + return &systemUDPPacketWriter6{ + s.tun, + s.frontHeadroom + PacketOffset, + headerCopy, + source, + s.txChecksumOffload, + } }) return nil } @@ -449,8 +467,10 @@ func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip sourceAddress := packet.SourceIP() packet.SetSourceIP(packet.DestinationIP()) packet.SetDestinationIP(sourceAddress) - header.ResetChecksum() - packet.ResetChecksum() + if !s.txChecksumOffload { + header.ResetChecksum() + packet.ResetChecksum() + } return nil } @@ -462,16 +482,19 @@ func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip sourceAddress := packet.SourceIP() packet.SetSourceIP(packet.DestinationIP()) packet.SetDestinationIP(sourceAddress) - header.ResetChecksum(packet.PseudoSum()) - packet.ResetChecksum() + if !s.txChecksumOffload { + header.ResetChecksum(packet.PseudoSum()) + packet.ResetChecksum() + } return nil } type systemUDPPacketWriter4 struct { - tun Tun - frontHeadroom int - header []byte - source netip.AddrPort + tun Tun + frontHeadroom int + header []byte + source netip.AddrPort + txChecksumOffload bool } func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { @@ -488,8 +511,10 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S udpHdr.SetDestinationPort(udpHdr.SourcePort()) udpHdr.SetSourcePort(destination.Port) udpHdr.SetLength(uint16(buffer.Len() + clashtcpip.UDPHeaderSize)) - udpHdr.ResetChecksum(ipHdr.PseudoSum()) - ipHdr.ResetChecksum() + if !w.txChecksumOffload { + udpHdr.ResetChecksum(ipHdr.PseudoSum()) + ipHdr.ResetChecksum() + } if PacketOffset > 0 { newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET } else { @@ -499,10 +524,11 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S } type systemUDPPacketWriter6 struct { - tun Tun - frontHeadroom int - header []byte - source netip.AddrPort + tun Tun + frontHeadroom int + header []byte + source netip.AddrPort + txChecksumOffload bool } func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { @@ -520,7 +546,9 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S udpHdr.SetDestinationPort(udpHdr.SourcePort()) udpHdr.SetSourcePort(destination.Port) udpHdr.SetLength(udpLen) - udpHdr.ResetChecksum(ipHdr.PseudoSum()) + if !w.txChecksumOffload { + udpHdr.ResetChecksum(ipHdr.PseudoSum()) + } if PacketOffset > 0 { newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6 } else { diff --git a/tun.go b/tun.go index 9610782..7a94a38 100644 --- a/tun.go +++ b/tun.go @@ -24,7 +24,6 @@ type Handler interface { type Tun interface { io.ReadWriter N.VectorisedWriter - N.FrontHeadroom Close() error } @@ -33,11 +32,13 @@ type WinTun interface { ReadPacket() ([]byte, func(), error) } -type BatchTUN interface { +type LinuxTUN interface { Tun + N.FrontHeadroom BatchSize() int BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) BatchWrite(buffers [][]byte, offset int) error + TXChecksumOffload() bool } type Options struct { @@ -46,6 +47,7 @@ type Options struct { Inet6Address []netip.Prefix MTU uint32 GSO bool + TXChecksumOffload bool AutoRoute bool StrictRoute bool Inet4RouteAddress []netip.Prefix diff --git a/tun_darwin.go b/tun_darwin.go index 4436026..26782f0 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -5,7 +5,6 @@ import ( "net" "net/netip" "os" - "runtime" "syscall" "unsafe" @@ -68,14 +67,9 @@ func New(options Options) (Tun, error) { if !ok { panic("create vectorised writer") } - runtime.SetFinalizer(nativeTun.tunFile, nil) return nativeTun, nil } -func (t *NativeTun) FrontHeadroom() int { - return 0 -} - func (t *NativeTun) Read(p []byte) (n int, err error) { return t.tunFile.Read(p) } diff --git a/tun_linux.go b/tun_linux.go index c7072b3..511c18f 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -24,7 +24,7 @@ import ( "golang.org/x/sys/unix" ) -var _ BatchTUN = (*NativeTun)(nil) +var _ LinuxTUN = (*NativeTun)(nil) type NativeTun struct { tunFd int @@ -40,6 +40,7 @@ type NativeTun struct { tcpGROAccess sync.Mutex tcp4GROTable *tcpGROTable tcp6GROTable *tcpGROTable + txChecksumOffload bool } func New(options Options) (Tun, error) { @@ -246,20 +247,17 @@ func (t *NativeTun) configure(tunLink netlink.Link) error { } if t.options.GSO { - vnethdrEnabled, err := checkVNETHDREnabled(uint16(t.tunFd), t.options.Name) + var vnetHdrEnabled bool + vnetHdrEnabled, err = checkVNETHDREnabled(t.tunFd, t.options.Name) if err != nil { return E.Cause(err, "enable offload: check IFF_VNET_HDR enabled") } - if !vnethdrEnabled { + if !vnetHdrEnabled { return E.Cause(err, "enable offload: IFF_VNET_HDR not enabled") } - const ( - // TODO: support TSO with ECN bits - tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 - ) - err = unix.IoctlSetInt(t.tunFd, unix.TUNSETOFFLOAD, tunOffloads) + err = setTCPOffload(t.tunFd) if err != nil { - return E.Cause(os.NewSyscallError("TUNSETOFFLOAD", err), "enable offload") + return err } t.gsoEnabled = true t.gsoBuffer = make([]byte, virtioNetHdrLen+int(gsoMaxSize)) @@ -267,6 +265,23 @@ func (t *NativeTun) configure(tunLink netlink.Link) error { t.tcp6GROTable = newTCPGROTable() } + if !t.options.TXChecksumOffload { + t.txChecksumOffload, _ = checkTxChecksumOffload(t.options.Name) + } else { + var txChecksumOffload bool + txChecksumOffload, err = checkTxChecksumOffload(t.options.Name) + if err != nil { + return err + } + if !txChecksumOffload { + err = setTxChecksumOffload(t.options.Name) + if err != nil { + return err + } + } + t.txChecksumOffload = true + } + err = netlink.LinkSetUp(tunLink) if err != nil { return err @@ -306,18 +321,6 @@ func (t *NativeTun) configure(tunLink netlink.Link) error { return nil } -func checkVNETHDREnabled(fd uint16, name string) (bool, error) { - ifr, err := unix.NewIfreq(name) - if err != nil { - return false, err - } - err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr) - if err != nil { - return false, os.NewSyscallError("TUNGETIFF", err) - } - return ifr.Uint16()&unix.IFF_VNET_HDR != 0, nil -} - func (t *NativeTun) Close() error { if t.interfaceCallback != nil { t.options.InterfaceMonitor.UnregisterCallback(t.interfaceCallback) @@ -325,6 +328,10 @@ func (t *NativeTun) Close() error { return E.Errors(t.unsetRoute(), t.unsetRules(), common.Close(common.PtrOrNil(t.tunFile))) } +func (t *NativeTun) TXChecksumOffload() bool { + return t.txChecksumOffload +} + func prefixToIPNet(prefix netip.Prefix) *net.IPNet { return &net.IPNet{ IP: prefix.Addr().AsSlice(), diff --git a/tun_linux_flags.go b/tun_linux_flags.go new file mode 100644 index 0000000..c2a14a7 --- /dev/null +++ b/tun_linux_flags.go @@ -0,0 +1,84 @@ +//go:build linux + +package tun + +import ( + "os" + "syscall" + "unsafe" + + E "github.com/sagernet/sing/common/exceptions" + + "golang.org/x/sys/unix" +) + +func checkVNETHDREnabled(fd int, name string) (bool, error) { + ifr, err := unix.NewIfreq(name) + if err != nil { + return false, err + } + err = unix.IoctlIfreq(fd, unix.TUNGETIFF, ifr) + if err != nil { + return false, os.NewSyscallError("TUNGETIFF", err) + } + return ifr.Uint16()&unix.IFF_VNET_HDR != 0, nil +} + +func setTCPOffload(fd int) error { + const ( + // TODO: support TSO with ECN bits + tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + ) + err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunOffloads) + if err != nil { + return E.Cause(os.NewSyscallError("TUNSETOFFLOAD", err), "enable offload") + } + return nil +} + +type ifreqData struct { + ifrName [unix.IFNAMSIZ]byte + ifrData uintptr +} + +type ethtoolValue struct { + cmd uint32 + data uint32 +} + +//go:linkname ioctlPtr golang.org/x/sys/unix.ioctlPtr +func ioctlPtr(fd int, req uint, arg unsafe.Pointer) (err error) + +func checkTxChecksumOffload(name string) (bool, error) { + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return false, err + } + defer syscall.Close(fd) + ifr := ifreqData{} + copy(ifr.ifrName[:], name) + data := ethtoolValue{cmd: unix.ETHTOOL_GTXCSUM} + ifr.ifrData = uintptr(unsafe.Pointer(&data)) + err = ioctlPtr(fd, unix.SIOCETHTOOL, unsafe.Pointer(&ifr)) + if err != nil { + return false, os.NewSyscallError("SIOCETHTOOL ETHTOOL_GTXCSUM", err) + } + return data.data == 1, nil +} + +func setTxChecksumOffload(name string) error { + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return err + } + defer syscall.Close(fd) + ifr := ifreqData{} + copy(ifr.ifrName[:], name) + data := ethtoolValue{cmd: unix.ETHTOOL_STXCSUM, data: 1} + ifr.ifrData = uintptr(unsafe.Pointer(&data)) + err = ioctlPtr(fd, unix.SIOCETHTOOL, unsafe.Pointer(&ifr)) + if err != nil { + return os.NewSyscallError("SIOCETHTOOL ETHTOOL_STXCSUM", err) + } + return nil +} diff --git a/tun_linux_gvisor.go b/tun_linux_gvisor.go index 8f044ca..ea07e9c 100644 --- a/tun_linux_gvisor.go +++ b/tun_linux_gvisor.go @@ -15,11 +15,13 @@ func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) { FDs: []int{t.tunFd}, MTU: t.options.MTU, RXChecksumOffload: true, + TXChecksumOffload: t.txChecksumOffload, }) } return fdbased.New(&fdbased.Options{ FDs: []int{t.tunFd}, MTU: t.options.MTU, RXChecksumOffload: true, + TXChecksumOffload: t.txChecksumOffload, }) } diff --git a/tun_windows.go b/tun_windows.go index 90a9867..7e1a0c3 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -65,10 +65,6 @@ func New(options Options) (WinTun, error) { return nativeTun, nil } -func (t *NativeTun) FrontHeadroom() int { - return 0 -} - func (t *NativeTun) configure() error { luid := winipcfg.LUID(t.adapter.LUID()) if len(t.options.Inet4Address) > 0 {