Skip to content

Commit

Permalink
webrtc: fix race in TestMuxedConnection (#2607)
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt authored Oct 18, 2023
1 parent 99f7611 commit 2b57e26
Showing 1 changed file with 33 additions and 18 deletions.
51 changes: 33 additions & 18 deletions p2p/transport/webrtc/udpmux/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -187,41 +186,57 @@ func TestMuxedConnection(t *testing.T) {
connCount := 3

ufrags := []string{"a", "b", "c"}
var mu sync.Mutex
addrUfragMap := make(map[string]string)
ufragConnsMap := make(map[string][]net.PacketConn)
for _, ufrag := range ufrags {
for i := 0; i < connCount; i++ {
cc := newPacketConn(t)
addrUfragMap[cc.LocalAddr().String()] = ufrag
ufragConnsMap[ufrag] = append(ufragConnsMap[ufrag], cc)
}
}

done := make(chan bool, len(ufrags))
for _, ufrag := range ufrags {
go func(ufrag string) {
for i := 0; i < connCount; i++ {
cc := newPacketConn(t)
mu.Lock()
addrUfragMap[cc.LocalAddr().String()] = ufrag
mu.Unlock()
for _, cc := range ufragConnsMap[ufrag] {
setupMapping(t, ufrag, cc, m)
for j := 0; j < msgCount; j++ {
cc.WriteTo([]byte(ufrag), c.LocalAddr())
}
}
done <- true
}(ufrag)
}
for i := 0; i < len(ufrags); i++ {
<-done
}

for _, ufrag := range ufrags {
mc, err := m.GetConn(ufrag, c.LocalAddr()) // the address is irrelevant
require.NoError(t, err)
msgs := 0
stunRequests := 0
msg := make([]byte, 1500)
addrPacketCount := make(map[string]int)
for i := 0; i < connCount; i++ {
msg := make([]byte, 100)
// Read the binding request
_, addr1, err := mc.ReadFrom(msg)
require.NoError(t, err)
require.Equal(t, addrUfragMap[addr1.String()], ufrag)
// Read individual msgs
for i := 0; i < msgCount; i++ {
n, addr2, err := mc.ReadFrom(msg)
for j := 0; j < msgCount+1; j++ {
n, addr1, err := mc.ReadFrom(msg)
require.NoError(t, err)
require.Equal(t, addr2, addr1)
require.Equal(t, ufrag, string(msg[:n]))
require.Equal(t, addrUfragMap[addr1.String()], ufrag)
addrPacketCount[addr1.String()]++
if stun.IsMessage(msg[:n]) {
stunRequests++
} else {
msgs++
}
}
delete(addrUfragMap, addr1.String())
}
for addr, v := range addrPacketCount {
require.Equal(t, v, msgCount+1) // msgCount msgs + 1 STUN binding request
delete(addrUfragMap, addr)
}
require.Equal(t, len(addrPacketCount), connCount)
}
require.Equal(t, len(addrUfragMap), 0)
}

0 comments on commit 2b57e26

Please sign in to comment.