Skip to content

Commit

Permalink
Fix crash in client and server with invalid handshake packets
Browse files Browse the repository at this point in the history
In both Dial() and Listen(), if the peer sends a V5 handshake packet
without the extension field, the library crashes. This patch fixes the
issue and adds unit tests.
  • Loading branch information
aler9 committed Jul 3, 2024
1 parent 79d9961 commit 9cc5f1c
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 0 deletions.
8 changes: 8 additions & 0 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,14 @@ func (dl *dialer) handleHandshake(p packet.Packet) {
sendTsbpdDelay := uint16(dl.config.PeerLatency.Milliseconds())

if cif.Version == 5 {
if cif.SRTHS == nil {
dl.connChan <- connResponse{
conn: nil,
err: fmt.Errorf("missing handshake extension"),
}
return
}

// Check if the peer version is sufficient
if cif.SRTHS.SRTVersion < dl.config.MinVersion {
dl.sendShutdown(cif.SRTSocketId)
Expand Down
77 changes: 77 additions & 0 deletions dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,80 @@ func TestDialV5(t *testing.T) {
pc.Close()
ln.Close()
}

func TestDialV5MissingExtension(t *testing.T) {
ln, err := net.ListenPacket("udp", "127.0.0.1:6003")
require.NoError(t, err)
defer ln.Close()

go func() {
// read induction request
buf := make([]byte, MAX_MSS_SIZE)
n, addr, err := ln.ReadFrom(buf)
require.NoError(t, err)
p, err := packet.NewPacketFromData(addr, buf[:n])
require.NoError(t, err)
recvcif := &packet.CIFHandshake{}
err = p.UnmarshalCIF(recvcif)
require.NoError(t, err)
require.Equal(t, packet.HSTYPE_INDUCTION, recvcif.HandshakeType)

// write induction response
p.Header().IsControlPacket = true
p.Header().ControlType = packet.CTRLTYPE_HANDSHAKE
p.Header().SubType = 0
p.Header().TypeSpecific = 0
p.Header().Timestamp = 0
p.Header().DestinationSocketId = recvcif.SRTSocketId
sendcif := &packet.CIFHandshake{
IsRequest: false,
Version: 5,
EncryptionField: 0,
ExtensionField: 0x4A17,
InitialPacketSequenceNumber: recvcif.InitialPacketSequenceNumber,
MaxTransmissionUnitSize: recvcif.MaxTransmissionUnitSize,
MaxFlowWindowSize: recvcif.MaxFlowWindowSize,
HandshakeType: packet.HSTYPE_INDUCTION,
SRTSocketId: recvcif.SRTSocketId,
SynCookie: 1234,
}
sendcif.PeerIP.FromNetAddr(ln.LocalAddr())
p.MarshalCIF(sendcif)
var outbuf bytes.Buffer
err = p.Marshal(&outbuf)
require.NoError(t, err)
ln.WriteTo(outbuf.Bytes(), p.Header().Addr)

// read conclusion request
n, addr, err = ln.ReadFrom(buf)
require.NoError(t, err)
p, err = packet.NewPacketFromData(addr, buf[:n])
require.NoError(t, err)
recvcif = &packet.CIFHandshake{}
err = p.UnmarshalCIF(recvcif)
require.NoError(t, err)
require.Equal(t, packet.HSTYPE_CONCLUSION, recvcif.HandshakeType)

// write invalid conclusion response
p.Header().IsControlPacket = true
p.Header().ControlType = packet.CTRLTYPE_HANDSHAKE
p.Header().SubType = 0
p.Header().TypeSpecific = 0
p.Header().Timestamp = 0
p.Header().DestinationSocketId = recvcif.SRTSocketId
sendcif = recvcif
sendcif.IsRequest = false
sendcif.SRTSocketId = 9876
sendcif.SynCookie = 0
sendcif.PeerIP.FromNetAddr(ln.LocalAddr())
sendcif.HasHS = false
p.MarshalCIF(sendcif)
outbuf.Reset()
err = p.Marshal(&outbuf)
require.NoError(t, err)
ln.WriteTo(outbuf.Bytes(), p.Header().Addr)
}()

_, err = Dial("srt", "127.0.0.1:6003", DefaultConfig())
require.EqualError(t, err, "missing handshake extension")
}
10 changes: 10 additions & 0 deletions listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,16 @@ func (ln *listener) handleHandshake(p packet.Packet) {
return
}
} else if cif.Version == 5 {
if cif.SRTHS == nil {
cif.HandshakeType = packet.HandshakeType(REJ_ROGUE)
ln.log("handshake:recv:error", func() string { return "missing handshake extension" })
p.MarshalCIF(cif)
ln.log("handshake:send:dump", func() string { return p.Dump() })
ln.log("handshake:send:cif", func() string { return cif.String() })
ln.send(p)
return
}

// Check if the peer version is sufficient
if cif.SRTHS.SRTVersion < config.MinVersion {
cif.HandshakeType = packet.HandshakeType(REJ_VERSION)
Expand Down
95 changes: 95 additions & 0 deletions listen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/datarhei/gosrt/circular"
"github.com/datarhei/gosrt/packet"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -435,3 +436,97 @@ func TestListenAsync(t *testing.T) {
ln.Close()
listenerWg.Wait()
}

func TestListenHSV5MissingExtension(t *testing.T) {
ln, err := Listen("srt", "127.0.0.1:6003", DefaultConfig())
require.NoError(t, err)

listenDone := make(chan struct{})
defer func() { <-listenDone }()

go func() {
defer close(listenDone)
for {
_, _, err := ln.Accept(func(req ConnRequest) ConnType {
return SUBSCRIBE
})
if err != nil {
break
}
}
}()

conn, err := net.Dial("udp", "127.0.0.1:6003")
require.NoError(t, err)
defer conn.Close()

// send induction request
p := packet.NewPacket(conn.RemoteAddr())
p.Header().IsControlPacket = true
p.Header().ControlType = packet.CTRLTYPE_HANDSHAKE
p.Header().SubType = 0
p.Header().TypeSpecific = 0
p.Header().Timestamp = 0
p.Header().DestinationSocketId = 0
sendcif := &packet.CIFHandshake{
IsRequest: true,
Version: 4,
EncryptionField: 0,
ExtensionField: 2,
InitialPacketSequenceNumber: circular.New(10000, packet.MAX_SEQUENCENUMBER),
MaxTransmissionUnitSize: MAX_MSS_SIZE,
MaxFlowWindowSize: 25600,
HandshakeType: packet.HSTYPE_INDUCTION,
SRTSocketId: 55555,
SynCookie: 0,
}
sendcif.PeerIP.FromNetAddr(conn.LocalAddr())
p.MarshalCIF(sendcif)
var buf bytes.Buffer
err = p.Marshal(&buf)
require.NoError(t, err)
_, err = conn.Write(buf.Bytes())
require.NoError(t, err)

// read induction response
inbuf := make([]byte, MAX_MSS_SIZE)
n, err := conn.Read(inbuf)
require.NoError(t, err)
p, err = packet.NewPacketFromData(conn.RemoteAddr(), inbuf[:n])
require.NoError(t, err)
recvcif := &packet.CIFHandshake{}
err = p.UnmarshalCIF(recvcif)
require.NoError(t, err)

// send conclusion
p.Header().IsControlPacket = true
p.Header().ControlType = packet.CTRLTYPE_HANDSHAKE
p.Header().SubType = 0
p.Header().TypeSpecific = 0
p.Header().Timestamp = 0
p.Header().DestinationSocketId = 0 // recvcif.SRTSocketId
sendcif.Version = 5
sendcif.ExtensionField = recvcif.ExtensionField
sendcif.HandshakeType = packet.HSTYPE_CONCLUSION
sendcif.SynCookie = recvcif.SynCookie
sendcif.HasSID = true
sendcif.StreamId = "foobar"
p.MarshalCIF(sendcif)
buf.Reset()
err = p.Marshal(&buf)
require.NoError(t, err)
_, err = conn.Write(buf.Bytes())
require.NoError(t, err)

// read error
n, err = conn.Read(inbuf)
require.NoError(t, err)
p, err = packet.NewPacketFromData(conn.RemoteAddr(), inbuf[:n])
require.NoError(t, err)
recvcif = &packet.CIFHandshake{}
err = p.UnmarshalCIF(recvcif)
require.NoError(t, err)
require.Equal(t, recvcif.HandshakeType, packet.HandshakeType(REJ_ROGUE))

ln.Close()
}

0 comments on commit 9cc5f1c

Please sign in to comment.