diff --git a/multicast/peer.go b/multicast/peer.go index 6af3fcda..7ea87059 100644 --- a/multicast/peer.go +++ b/multicast/peer.go @@ -554,10 +554,15 @@ func (p *UDPPeer) unblockIPv4(multicastIP, sourceIP netip.Addr) (err error) { func (p *UDPPeer) unblockIPv6(multicastIP, sourceIP netip.Addr) (err error) { panic("IPv6 multicast peer not yet supported") } + func (p *UDPPeer) Read(b []byte) (int, netip.AddrPort, error) { return p.socket.RecvFrom(b, 0) } +func (p *UDPPeer) SetAsyncReadBuffer(to []byte) { + p.read.b = to +} + func (p *UDPPeer) AsyncRead(b []byte, fn func(error, int, netip.AddrPort)) { p.read.b = b p.read.fn = fn diff --git a/multicast/peer_ipv4_test.go b/multicast/peer_ipv4_test.go index 17015df2..b74ecda9 100644 --- a/multicast/peer_ipv4_test.go +++ b/multicast/peer_ipv4_test.go @@ -6,6 +6,7 @@ import ( "log" "net" "net/netip" + "sort" "sync" "sync/atomic" "testing" @@ -1974,3 +1975,124 @@ func TestUDPPeerIPv4_ReaderWriter(t *testing.T) { } } } + +func TestUDPPeerIPv4_MultipleReadersSameBuffer(t *testing.T) { + ioc := sonic.MustIO() + defer ioc.Close() + + var ( + ips = []string{"224.0.0.19", "224.0.0.20"} + ports = []int{1234, 4321} + addrs []netip.AddrPort + ) + for i := 0; i < 2; i++ { + addr, err := netip.ParseAddrPort( + fmt.Sprintf("%s:%d", ips[i], ports[i])) + if err != nil { + t.Fatal(err) + } + addrs = append(addrs, addr) + } + + var ( + chunk1, chunk2 [4]byte + b []byte + parity = 0 + readers []*UDPPeer + + read []int + ) + + for i := 0; i < 2; i++ { + r, err := NewUDPPeer(ioc, "udp", addrs[i].String()) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if err := r.Join(IP(ips[i])); err != nil { + t.Fatalf("reader could not join %s", ips[i]) + } else { + log.Printf("reader joined group %s", ips[i]) + } + + id := i + + var fn func(error, int, netip.AddrPort) + fn = func(err error, _ int, _ netip.AddrPort) { + if err != nil { + t.Fatal(err) + } + + v := binary.BigEndian.Uint32(b) + log.Printf( + "reader %d read %d from %p", + id, + v, + b, + ) + read = append(read, int(v)) + + parity++ + if parity%2 == 0 { + b = chunk1[:] + } else { + b = chunk2[:] + } + + for _, reader := range readers { + reader.SetAsyncReadBuffer(b) + } + r.AsyncRead(b, fn) + } + b = chunk1[:] + r.AsyncRead(b, fn) + + readers = append(readers, r) + } + + var writers []*UDPPeer + for i := 0; i < 2; i++ { + w, err := NewUDPPeer(ioc, "udp", "") + if err != nil { + t.Fatal(err) + } + defer w.Close() + + writers = append(writers, w) + } + + var wb [4]byte + const Nops = 32 + + for i := 0; i < Nops; i++ { + time.Sleep(time.Millisecond) + + ix := i % 2 + binary.BigEndian.PutUint32(wb[:], uint32(i)) + + _, err := writers[ix].Write(wb[:], addrs[ix]) + if err != nil && err != sonicerrors.ErrWouldBlock { + t.Fatalf("on the %d write err=%v", i, err) + } + + for j := 0; j < Nops; j++ { + ioc.PollOne() + } + } + for j := 0; j < Nops; j++ { + ioc.PollOne() + } + + // assert + sort.Ints(read) + + if len(read) != Nops { + t.Fatal("did not read correctly") + } + for i := 0; i < Nops; i++ { + if read[i] != i { + t.Fatal("did not read correctly") + } + } +}