Skip to content

Commit

Permalink
block repeated handshake requests
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 committed Jul 3, 2024
1 parent ac83781 commit 96fc2cf
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 8 deletions.
24 changes: 20 additions & 4 deletions conn_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func newConnRequest(ln *listener, p packet.Packet) *connRequest {
return nil
}

c := &connRequest{
req := &connRequest{
ln: ln,
addr: p.Header().Addr,
start: time.Now(),
Expand All @@ -229,10 +229,22 @@ func newConnRequest(ln *listener, p packet.Packet) *connRequest {
return nil
}

c.crypto = cr
req.crypto = cr
}

return c
ln.lock.Lock()
_, exists := ln.connReqs[cif.SRTSocketId]
if !exists {
ln.connReqs[cif.SRTSocketId] = req
}
ln.lock.Unlock()

// we received a duplicate request: reject silently
if exists {
return nil
}

return req
} else {
if cif.HandshakeType.IsRejection() {
ln.log("handshake:recv:error", func() string { return fmt.Sprintf("connection rejected: %s", cif.HandshakeType.String()) })
Expand Down Expand Up @@ -282,6 +294,10 @@ func (req *connRequest) SetRejectionReason(reason RejectionReason) {
}

func (req *connRequest) Reject(reason RejectionReason) {
req.ln.lock.Lock()
delete(req.ln.connReqs, req.socketId)
req.ln.lock.Unlock()

p := packet.NewPacket(req.addr)
p.Header().IsControlPacket = true
p.Header().ControlType = packet.CTRLTYPE_HANDSHAKE
Expand Down Expand Up @@ -375,9 +391,9 @@ func (req *connRequest) Accept() (Conn, error) {
req.ln.log("handshake:send:cif", func() string { return req.handshake.String() })
req.ln.send(p)

// Add the connection to the list of known connections
req.ln.lock.Lock()
req.ln.conns[socketId] = conn
delete(req.ln.connReqs, req.socketId)
req.ln.lock.Unlock()

return conn, nil
Expand Down
8 changes: 5 additions & 3 deletions listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,10 @@ type listener struct {

config Config

backlog chan packet.Packet
conns map[uint32]*srtConn
lock sync.RWMutex
backlog chan packet.Packet
connReqs map[uint32]*connRequest
conns map[uint32]*srtConn
lock sync.RWMutex

start time.Time

Expand Down Expand Up @@ -189,6 +190,7 @@ func Listen(network, address string, config Config) (Listener, error) {
return nil, fmt.Errorf("listen: no local address")
}

ln.connReqs = make(map[uint32]*connRequest)
ln.conns = make(map[uint32]*srtConn)

ln.backlog = make(chan packet.Packet, 128)
Expand Down
114 changes: 113 additions & 1 deletion listen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/datarhei/gosrt/circular"
"github.com/datarhei/gosrt/packet"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -441,6 +442,7 @@ func TestListenParallelRequests(t *testing.T) {
require.NoError(t, err)

listenDone := make(chan struct{})
defer func() { <-listenDone }()

var reqReady sync.WaitGroup
reqReady.Add(4)
Expand Down Expand Up @@ -495,5 +497,115 @@ func TestListenParallelRequests(t *testing.T) {
clientSideConnReady.Wait()

ln.Close()
<-listenDone
}

func TestListenDiscardRepeatedHandshakes(t *testing.T) {
ln, err := Listen("srt", "127.0.0.1:6003", DefaultConfig())
require.NoError(t, err)

listenDone := make(chan struct{})
defer func() { <-listenDone }()

singleReqReceived := make(chan struct{})

go func() {
defer close(listenDone)

var onlyRequest ConnRequest

for {
req, err := ln.Accept2()
if err != nil {
break
}

close(singleReqReceived)
onlyRequest = req
}

onlyRequest.Reject(REJ_CLOSE)
}()

for i := 0; i < 4; i++ {
conn, err := net.Dial("udp", "127.0.0.1:6003")
require.NoError(t, err)
defer conn.Close()

// send induction request
p := packet.NewPacket(conn.RemoteAddr())
p.Header().IsControlPacket = true
p.Header().ControlType = packet.CTRLTYPE_HANDSHAKE
p.Header().SubType = 0
p.Header().TypeSpecific = 0
p.Header().Timestamp = 0
p.Header().DestinationSocketId = 0
sendcif := &packet.CIFHandshake{
IsRequest: true,
Version: 4,
EncryptionField: 0,
ExtensionField: 2,
InitialPacketSequenceNumber: circular.New(10000, packet.MAX_SEQUENCENUMBER),
MaxTransmissionUnitSize: MAX_MSS_SIZE,
MaxFlowWindowSize: 25600,
HandshakeType: packet.HSTYPE_INDUCTION,
SRTSocketId: 55555,
SynCookie: 0,
}
sendcif.PeerIP.FromNetAddr(conn.LocalAddr())
p.MarshalCIF(sendcif)
var buf bytes.Buffer
err = p.Marshal(&buf)
require.NoError(t, err)
_, err = conn.Write(buf.Bytes())
require.NoError(t, err)

// read induction response
inbuf := make([]byte, 1024)
n, err := conn.Read(inbuf)
require.NoError(t, err)
p, err = packet.NewPacketFromData(conn.RemoteAddr(), inbuf[:n])
require.NoError(t, err)
recvcif := &packet.CIFHandshake{}
err = p.UnmarshalCIF(recvcif)
require.NoError(t, err)

// send conclusion
p.Header().IsControlPacket = true
p.Header().ControlType = packet.CTRLTYPE_HANDSHAKE
p.Header().SubType = 0
p.Header().TypeSpecific = 0
p.Header().Timestamp = 0
p.Header().DestinationSocketId = 0 // recvcif.SRTSocketId
sendcif.Version = 5
sendcif.ExtensionField = recvcif.ExtensionField
sendcif.HandshakeType = packet.HSTYPE_CONCLUSION
sendcif.SynCookie = recvcif.SynCookie
sendcif.HasHS = true
sendcif.SRTHS = &packet.CIFHandshakeExtension{
SRTVersion: SRT_VERSION,
SRTFlags: packet.CIFHandshakeExtensionFlags{
TSBPDSND: true,
TSBPDRCV: true,
CRYPT: true, // must always set to true
TLPKTDROP: true,
PERIODICNAK: true,
REXMITFLG: true,
STREAM: false,
PACKET_FILTER: false,
},
RecvTSBPDDelay: uint16(120),
SendTSBPDDelay: uint16(120),
}
sendcif.HasSID = true
sendcif.StreamId = "foobar"
p.MarshalCIF(sendcif)
buf.Reset()
err = p.Marshal(&buf)
require.NoError(t, err)
_, err = conn.Write(buf.Bytes())
require.NoError(t, err)
}

<-singleReqReceived
ln.Close()
}

0 comments on commit 96fc2cf

Please sign in to comment.