diff --git a/stack_mixed.go b/stack_mixed.go index b52c996..7c1a8f9 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -145,15 +145,13 @@ func (m *Mixed) wintunLoop(winTun WinTun) { func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) { frontHeadroom := m.tun.FrontHeadroom() packetBuffers := make([][]byte, batchSize) - readBuffers := make([][]byte, batchSize) writeBuffers := make([][]byte, batchSize) packetSizes := make([]int, batchSize) for i := range packetBuffers { - packetBuffers[i] = make([]byte, m.mtu+frontHeadroom+PacketOffset) - readBuffers[i] = packetBuffers[i][frontHeadroom:] + packetBuffers[i] = make([]byte, m.mtu+frontHeadroom) } for { - n, err := linuxTUN.BatchRead(readBuffers, packetSizes) + n, err := linuxTUN.BatchRead(packetBuffers, frontHeadroom, packetSizes) if err != nil { if E.IsClosed(err) { return @@ -169,13 +167,13 @@ func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) { continue } packetBuffer := packetBuffers[i] - packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize] + packet := packetBuffer[frontHeadroom : frontHeadroom+packetSize] if m.processPacket(packet) { writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize]) } } if len(writeBuffers) > 0 { - err = linuxTUN.BatchWrite(writeBuffers) + err = linuxTUN.BatchWrite(writeBuffers, frontHeadroom) if err != nil { m.logger.Trace(E.Cause(err, "batch write packet")) } diff --git a/stack_system.go b/stack_system.go index 73a83ae..8b687fa 100644 --- a/stack_system.go +++ b/stack_system.go @@ -198,15 +198,13 @@ func (s *System) wintunLoop(winTun WinTun) { func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) { frontHeadroom := s.tun.FrontHeadroom() packetBuffers := make([][]byte, batchSize) - readBuffers := make([][]byte, batchSize) writeBuffers := make([][]byte, batchSize) packetSizes := make([]int, batchSize) for i := range packetBuffers { - packetBuffers[i] = make([]byte, s.mtu+frontHeadroom+PacketOffset) - readBuffers[i] = packetBuffers[i][frontHeadroom:] + packetBuffers[i] = make([]byte, s.mtu+frontHeadroom) } for { - n, err := linuxTUN.BatchRead(readBuffers, packetSizes) + n, err := linuxTUN.BatchRead(packetBuffers, frontHeadroom, packetSizes) if err != nil { if E.IsClosed(err) { return @@ -222,13 +220,13 @@ func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) { continue } packetBuffer := packetBuffers[i] - packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize] + packet := packetBuffer[frontHeadroom : frontHeadroom+packetSize] if s.processPacket(packet) { writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize]) } } if len(writeBuffers) > 0 { - err = linuxTUN.BatchWrite(writeBuffers) + err = linuxTUN.BatchWrite(writeBuffers, frontHeadroom) if err != nil { s.logger.Trace(E.Cause(err, "batch write packet")) } diff --git a/tun.go b/tun.go index 5d853ee..9610782 100644 --- a/tun.go +++ b/tun.go @@ -36,8 +36,8 @@ type WinTun interface { type BatchTUN interface { Tun BatchSize() int - BatchRead(buffers [][]byte, readN []int) (n int, err error) - BatchWrite(buffers [][]byte) error + BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) + BatchWrite(buffers [][]byte, offset int) error } type Options struct { diff --git a/tun_linux.go b/tun_linux.go index 31903ef..c7072b3 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -35,6 +35,8 @@ type NativeTun struct { ruleIndex6 []int gsoEnabled bool gsoBuffer []byte + gsoToWrite []int + gsoReadAccess sync.Mutex tcpGROAccess sync.Mutex tcp4GROTable *tcpGROTable tcp6GROTable *tcpGROTable @@ -105,7 +107,7 @@ func (t *NativeTun) Read(p []byte) (n int, err error) { func (t *NativeTun) Write(p []byte) (n int, err error) { if t.gsoEnabled { - err = t.BatchWrite([][]byte{p}) + err = t.BatchWrite([][]byte{p}, virtioNetHdrLen) if err != nil { return } @@ -140,37 +142,31 @@ func (t *NativeTun) BatchSize() int { return batchSize } -func (t *NativeTun) BatchRead(buffers [][]byte, readN []int) (n int, err error) { - if t.gsoEnabled { - n, err = t.tunFile.Read(t.gsoBuffer) - if err != nil { - return - } - n, err = handleVirtioRead(t.gsoBuffer[:n], buffers, readN, 0) - if err != nil { - return - } - +func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) { + t.gsoReadAccess.Lock() + defer t.gsoReadAccess.Unlock() + n, err = t.tunFile.Read(t.gsoBuffer) + if err != nil { return - } else { - return 0, os.ErrInvalid } + return handleVirtioRead(t.gsoBuffer[:n], buffers, readN, offset) } -func (t *NativeTun) BatchWrite(buffers [][]byte) error { +func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) error { t.tcpGROAccess.Lock() defer func() { t.tcp4GROTable.reset() t.tcp6GROTable.reset() t.tcpGROAccess.Unlock() }() - var toWrite []int - err := handleGRO(buffers, virtioNetHdrLen, t.tcp4GROTable, t.tcp6GROTable, &toWrite) + t.gsoToWrite = t.gsoToWrite[:0] + err := handleGRO(buffers, offset, t.tcp4GROTable, t.tcp6GROTable, &t.gsoToWrite) if err != nil { return err } - for _, bufferIndex := range toWrite { - _, err = t.tunFile.Write(buffers[bufferIndex]) + offset -= virtioNetHdrLen + for _, bufferIndex := range t.gsoToWrite { + _, err = t.tunFile.Write(buffers[bufferIndex][offset:]) if err != nil { return err } diff --git a/tun_linux_offload.go b/tun_linux_offload.go index d99a8f2..930b939 100644 --- a/tun_linux_offload.go +++ b/tun_linux_offload.go @@ -750,8 +750,12 @@ func checksumNoFold(b []byte, initial uint64) uint64 { } func checksumFold(b []byte, initial uint64) uint16 { - r := clashtcpip.Checksum(uint32(initial), b) - return binary.BigEndian.Uint16(r[:]) + ac := checksumNoFold(b, initial) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + return uint16(ac) } func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {