Skip to content

Commit

Permalink
code optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
xtaci committed Dec 9, 2024
1 parent a50b0bd commit f52841a
Showing 1 changed file with 71 additions and 73 deletions.
144 changes: 71 additions & 73 deletions hopper.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,63 +140,41 @@ func (l *Listener) Start() {

// packetIn processes incoming packets and forwards them to the next hop.
func (l *Listener) packetIn(data []byte, raddr net.Addr) {
// decrypt incoming packet if crypterIn is set
packetOk := false
if l.crypterIn != nil && len(data) >= headerSize {
l.crypterIn.Decrypt(data, data)
checksum := crc32.ChecksumIEEE(data[headerSize:])
if checksum != binary.LittleEndian.Uint32(data[checksumOffset:]) {
l.logger.Println("packetIn checksum mismatch")
return
}
data = data[headerSize:]
packetOk = true
} else if l.crypterIn == nil {
packetOk = true
// decrypt the packet if crypterIn is set
data, err := decryptPacket(l.crypterIn, data)
if err != nil {
l.logger.Println("decrypt error:", err)
return
}

if packetOk {
// encrypt or re-encrypt the packet if crypterOut is set(with new nonce)
if l.crypterOut != nil {
dataOut := make([]byte, len(data)+headerSize)
copy(dataOut[headerSize:], data)
// fill the nonce(12 bytes)
_, _ = io.ReadFull(rand.Reader, dataOut[nonceOffset:nonceOffset+nonceSize])
// fill the checksum(4 bytes)
checksum := crc32.ChecksumIEEE(data)
binary.LittleEndian.PutUint32(dataOut[checksumOffset:], checksum)
// encrypt the packet
l.crypterOut.Encrypt(dataOut, dataOut)
//fmt.Println(unsafe.Pointer(l), "encrypted listener out", string(dataOut))
data = dataOut
}

// load the connection from the incoming connections
l.incomingConnectionsLock.RLock()
conn, ok := l.incomingConnections[raddr.String()]
l.incomingConnectionsLock.RUnlock()

if ok { // existing connection
l.watcher.WriteTimeout(nil, conn, data, time.Now().Add(l.timeout))
} else { // new connection
// dial the next hop
conn, err := net.Dial("udp", l.nextHop)
if err != nil {
l.logger.Println("dial target error:", err)
return
}
// encrypt or re-encrypt the packet if crypterOut is set(with new nonce)
data = encryptPacket(l.crypterOut, data)

// add the connection to the incoming connections
l.addClient(raddr, conn)
// log new connection
l.logger.Printf("new connection from %s to %s", raddr.String(), l.nextHop)
// load the connection from the incoming connections
l.incomingConnectionsLock.RLock()
conn, ok := l.incomingConnections[raddr.String()]
l.incomingConnectionsLock.RUnlock()

// watch the connection
// the context is the address of incoming packet
ctx := raddr
l.watcher.ReadTimeout(ctx, conn, make([]byte, mtuLimit), time.Now().Add(l.timeout))
l.watcher.WriteTimeout(nil, conn, data, time.Now().Add(l.timeout)) // write needs not to specify the context(where the packet from)
if ok { // existing connection
l.watcher.WriteTimeout(nil, conn, data, time.Now().Add(l.timeout))
} else { // new connection
// dial the next hop
conn, err := net.Dial("udp", l.nextHop)
if err != nil {
l.logger.Println("dial target error:", err)
return
}

// add the connection to the incoming connections
l.addClient(raddr, conn)
// log new connection
l.logger.Printf("new connection from %s to %s", raddr.String(), l.nextHop)

// watch the connection
// the context is the address of incoming packet
ctx := raddr
l.watcher.ReadTimeout(ctx, conn, make([]byte, mtuLimit), time.Now().Add(l.timeout))
l.watcher.WriteTimeout(nil, conn, data, time.Now().Add(l.timeout)) // write needs not to specify the context(where the packet from)
}
}

Expand Down Expand Up @@ -231,30 +209,14 @@ func (l *Listener) switcher() {
dataFromProxy := res.Buffer[:res.Size]

// decrypt data from the proxy connection if crypterOut is set.
if l.crypterOut != nil {
l.crypterOut.Decrypt(dataFromProxy, dataFromProxy)
checksum := crc32.ChecksumIEEE(dataFromProxy[headerSize:])
if checksum != binary.LittleEndian.Uint32(dataFromProxy[checksumOffset:]) {
l.logger.Println("crypterOut checksum mismatch")
continue
}
dataFromProxy = dataFromProxy[headerSize:]
//fmt.Println(unsafe.Pointer(l), "proxy crypterOut", string(dataFromProxy))
dataFromProxy, err := decryptPacket(l.crypterOut, dataFromProxy)
if err != nil {
l.logger.Println("decrypt error:", err)
continue
}

// re-encrypt data if crypterIn is set.
if l.crypterIn != nil {
data := make([]byte, len(dataFromProxy)+headerSize)
copy(data[headerSize:], dataFromProxy)
// fill the nonce(12 bytes)
_, _ = io.ReadFull(rand.Reader, data[nonceOffset:nonceOffset+nonceSize])
// fill the checksum(4 bytes)
checksum := crc32.ChecksumIEEE(dataFromProxy)
binary.LittleEndian.PutUint32(data[checksumOffset:], checksum)
// encrypt the packet
l.crypterIn.Encrypt(data, data)
dataFromProxy = data
}
dataFromProxy = encryptPacket(l.crypterIn, dataFromProxy)

// forward the data to client via the listener.
l.conn.WriteTo(dataFromProxy, res.Context.(net.Addr))
Expand Down Expand Up @@ -289,3 +251,39 @@ func (l *Listener) Close() error {
})
return nil
}

// decryptPacket decrypts the packet using the provided crypter.
// It returns the decrypted data or an error if the checksum does not match.
func decryptPacket(crypter BlockCrypt, packet []byte) (data []byte, err error) {
if crypter != nil && len(packet) >= headerSize {
crypter.Decrypt(packet, packet)
checksum := crc32.ChecksumIEEE(packet[headerSize:])
if checksum != binary.LittleEndian.Uint32(packet[checksumOffset:]) {
return nil, errors.New("checksum mismatch")
}
data = packet[headerSize:]
} else if crypter == nil {
data = packet
}

return data, nil
}

// encryptPacket encrypts the packet using the provided crypter.
// It returns the encrypted data or the original data if no crypter is provided.
func encryptPacket(crypter BlockCrypt, data []byte) (packet []byte) {
if crypter != nil {
packet = make([]byte, len(data)+headerSize)
copy(packet[headerSize:], data)
// fill the nonce(12 bytes)
_, _ = io.ReadFull(rand.Reader, packet[nonceOffset:nonceOffset+nonceSize])
// fill the checksum(4 bytes)
checksum := crc32.ChecksumIEEE(data)
binary.LittleEndian.PutUint32(packet[checksumOffset:], checksum)
// encrypt the packet
crypter.Encrypt(packet, packet)
} else {
packet = data
}
return
}

0 comments on commit f52841a

Please sign in to comment.