Skip to content

Commit

Permalink
webrtc: wait for fin_ack for closing datachannel
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Dec 5, 2023
1 parent 97b4ca0 commit e178585
Show file tree
Hide file tree
Showing 13 changed files with 618 additions and 230 deletions.
7 changes: 7 additions & 0 deletions core/network/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ type MuxedStream interface {
SetWriteDeadline(time.Time) error
}

// AsyncCloser is implemented by streams that need to do expensive operations on close before
// releasing the resources. Closing the stream async avoids blocking the calling goroutine.
// Callers that find this interface implemented should prefer AsyncClose over Close.
type AsyncCloser interface {
	// AsyncClose closes the stream and executes onDone after the stream is closed.
	// The returned error reports failures of the close operation itself.
	AsyncClose(onDone func()) error
}

// MuxedConn represents a connection to a remote peer that has been
// extended to support stream multiplexing.
//
Expand Down
6 changes: 6 additions & 0 deletions p2p/net/swarm/swarm_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ func (s *Stream) Write(p []byte) (int, error) {
// Close closes the stream, closing both ends and freeing all associated
// resources.
func (s *Stream) Close() error {
if as, ok := s.stream.(network.AsyncCloser); ok {
err := as.AsyncClose(func() {
s.closeAndRemoveStream()
})
return err
}
err := s.stream.Close()
s.closeAndRemoveStream()
return err
Expand Down
45 changes: 45 additions & 0 deletions p2p/net/swarm/swarm_stream_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package swarm

import (
"context"
"sync/atomic"
"testing"

"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/stretchr/testify/require"
)

// asyncStreamWrapper decorates a MuxedStream so that it also implements
// network.AsyncCloser, letting the test observe (via beforeClose) that the
// swarm routed Close through the AsyncClose path.
type asyncStreamWrapper struct {
	network.MuxedStream
	beforeClose func() // invoked at the start of AsyncClose, before the underlying Close
}

// AsyncClose implements network.AsyncCloser: it fires the beforeClose hook,
// closes the wrapped stream, and then signals completion through onDone.
func (s *asyncStreamWrapper) AsyncClose(onDone func()) error {
	s.beforeClose()
	defer onDone()
	return s.Close()
}

// TestStreamAsyncCloser verifies that (*Stream).Close routes through the
// network.AsyncCloser interface when the underlying muxed stream implements it.
func TestStreamAsyncCloser(t *testing.T) {
	s1 := makeSwarm(t)
	s2 := makeSwarm(t)

	s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.TempAddrTTL)
	s, err := s1.NewStream(context.Background(), s2.LocalPeer())
	require.NoError(t, err)
	ss, ok := s.(*Stream)
	require.True(t, ok)

	// Swap the muxed stream for a wrapper that records whether AsyncClose ran.
	var called atomic.Bool
	as := &asyncStreamWrapper{
		MuxedStream: ss.stream,
		beforeClose: func() {
			called.Store(true)
		},
	}
	ss.stream = as
	// AsyncClose propagates the underlying Close error; assert it instead of
	// dropping it, since error propagation is part of the contract under test.
	require.NoError(t, ss.Close())
	require.True(t, called.Load(), "Close should have dispatched to AsyncClose")
}
83 changes: 83 additions & 0 deletions p2p/test/swarm/swarm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package swarm_test

import (
"context"
"fmt"
"io"
"sync"
"testing"
Expand All @@ -14,6 +15,7 @@ import (
rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -243,3 +245,84 @@ func TestLimitStreamsWhenHangingHandlers(t *testing.T) {
return false
}, 5*time.Second, 100*time.Millisecond)
}

// TestLimitStreamsWhenHangingHandlersWebRTC checks that the resource manager's
// system-wide stream limit is enforced over the WebRTC transport: once
// streamLimit streams are held open by hanging handlers, new streams fail,
// and once the handlers are released, new streams succeed again.
func TestLimitStreamsWhenHangingHandlersWebRTC(t *testing.T) {
	var partial rcmgr.PartialLimitConfig
	const streamLimit = 10
	partial.System.Streams = streamLimit
	mgr, err := rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(partial.Build(rcmgr.InfiniteLimits)))
	require.NoError(t, err)

	maddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/webrtc-direct")
	require.NoError(t, err)

	// Only the receiver is stream-limited; senders below get infinite limits.
	receiver, err := libp2p.New(
		libp2p.ResourceManager(mgr),
		libp2p.ListenAddrs(maddr),
		libp2p.Transport(libp2pwebrtc.New),
	)
	require.NoError(t, err)
	t.Cleanup(func() { receiver.Close() })

	var wg sync.WaitGroup
	wg.Add(1)

	const pid = "/test"
	// Handler writes one byte to signal the stream was accepted, then hangs
	// on wg until the end of the test, keeping the stream slot occupied.
	receiver.SetStreamHandler(pid, func(s network.Stream) {
		defer s.Close()
		s.Write([]byte{42})
		wg.Wait()
	})

	// Open streamLimit streams
	success := 0
	// we make a lot of tries because identify and identify push take up a few streams
	for i := 0; i < 1000 && success < streamLimit; i++ {
		mgr, err = rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(rcmgr.InfiniteLimits))
		require.NoError(t, err)

		// A fresh sender host per attempt so each stream counts against the
		// receiver's system limit rather than a per-peer limit.
		sender, err := libp2p.New(libp2p.ResourceManager(mgr), libp2p.Transport(libp2pwebrtc.New))
		require.NoError(t, err)
		t.Cleanup(func() { sender.Close() })

		sender.Peerstore().AddAddrs(receiver.ID(), receiver.Addrs(), peerstore.PermanentAddrTTL)

		s, err := sender.NewStream(context.Background(), receiver.ID(), pid)
		if err != nil {
			continue
		}

		// Reading the signal byte confirms the handler is running (and hanging).
		var b [1]byte
		_, err = io.ReadFull(s, b[:])
		if err == nil {
			success++
		}
		sender.Close()
	}
	require.Equal(t, streamLimit, success)
	// We have the maximum number of streams open. Next call should fail.
	mgr, err = rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(rcmgr.InfiniteLimits))
	require.NoError(t, err)

	sender, err := libp2p.New(libp2p.ResourceManager(mgr), libp2p.Transport(libp2pwebrtc.New))
	require.NoError(t, err)
	t.Cleanup(func() { sender.Close() })

	sender.Peerstore().AddAddrs(receiver.ID(), receiver.Addrs(), peerstore.PermanentAddrTTL)

	_, err = sender.NewStream(context.Background(), receiver.ID(), pid)
	require.Error(t, err)
	// Close the open streams
	wg.Done()

	// Next call should succeed
	require.Eventually(t, func() bool {
		s, err := sender.NewStream(context.Background(), receiver.ID(), pid)
		if err == nil {
			s.Close()
			return true
		}
		// NOTE(review): leftover debug output — prefer t.Log(err); removing it
		// would also orphan the "fmt" import unless fmt is used elsewhere.
		fmt.Println(err)
		return false
	}, 5*time.Second, 1*time.Second)
}
3 changes: 0 additions & 3 deletions p2p/test/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,6 @@ func TestMoreStreamsThanOurLimits(t *testing.T) {
const streamCount = 1024
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
if strings.Contains(tc.Name, "WebRTC") {
t.Skip("This test potentially exhausts the uint16 WebRTC stream ID space.")
}
listenerLimits := rcmgr.PartialLimitConfig{
PeerDefault: rcmgr.ResourceLimits{
Streams: 32,
Expand Down
110 changes: 48 additions & 62 deletions p2p/transport/webrtc/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ import (
"context"
"errors"
"fmt"
"math"
"net"
"sync"
"sync/atomic"

ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
Expand All @@ -25,9 +23,7 @@ import (

var _ tpt.CapableConn = &connection{}

const maxAcceptQueueLen = 10

const maxDataChannelID = 1 << 10
const maxAcceptQueueLen = 256

type errConnectionTimeout struct{}

Expand All @@ -47,7 +43,8 @@ type connection struct {
transport *WebRTCTransport
scope network.ConnManagementScope

closeErr error
closeOnce sync.Once
closeErr error

localPeer peer.ID
localMultiaddr ma.Multiaddr
Expand All @@ -56,9 +53,8 @@ type connection struct {
remoteKey ic.PubKey
remoteMultiaddr ma.Multiaddr

m sync.Mutex
streams map[uint16]*stream
nextStreamID atomic.Int32
m sync.Mutex
streams map[uint16]*stream

acceptQueue chan dataChannel

Expand Down Expand Up @@ -97,25 +93,12 @@ func newConnection(

acceptQueue: make(chan dataChannel, maxAcceptQueueLen),
}
switch direction {
case network.DirInbound:
c.nextStreamID.Store(1)
case network.DirOutbound:
// stream ID 0 is used for the Noise handshake stream
c.nextStreamID.Store(2)
}

pc.OnConnectionStateChange(c.onConnectionStateChange)
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
if c.IsClosed() {
return
}
// Limit the number of streams, since we're not able to actually properly close them.
// See https://github.com/libp2p/specs/issues/575 for details.
if *dc.ID() > maxDataChannelID {
c.Close()
return
}
dc.OnOpen(func() {
rwc, err := dc.Detach()
if err != nil {
Expand All @@ -133,7 +116,6 @@ func newConnection(
}
})
})

return c, nil
}

Expand All @@ -144,16 +126,41 @@ func (c *connection) ConnState() network.ConnectionState {

// Close closes the underlying peerconnection. It is idempotent: repeated calls
// (and a racing closeTimedOut) are collapsed by closeOnce. Always returns nil.
func (c *connection) Close() error {
	if c.IsClosed() {
		return nil
	}
	c.closeOnce.Do(func() {
		c.closeErr = errors.New("connection closed")
		// cancel must be called after closeErr is set. This ensures interested goroutines waiting on
		// ctx.Done can read closeErr without holding the conn lock.
		c.cancel()
		// Detach the stream map under the lock, then reset the streams outside
		// it so per-stream work doesn't run while holding c.m.
		c.m.Lock()
		streams := c.streams
		c.streams = nil
		c.m.Unlock()
		for _, str := range streams {
			str.Reset()
		}
		// NOTE(review): errors from pc.Close are intentionally dropped here.
		c.pc.Close()
		c.scope.Done()
	})
	return nil
}

c.m.Lock()
defer c.m.Unlock()
c.scope.Done()
c.closeErr = errors.New("connection closed")
c.cancel()
return c.pc.Close()
// closeTimedOut closes the connection after the peer connection failed or
// closed underneath us. Unlike Close, open streams are failed with
// errConnectionTimeout{} (not a plain reset) so readers/writers can tell a
// timeout apart from an orderly close. Idempotent via closeOnce; returns nil.
func (c *connection) closeTimedOut() error {
	c.closeOnce.Do(func() {
		c.closeErr = errConnectionTimeout{}
		// cancel must be called after closeErr is set. This ensures interested goroutines waiting on
		// ctx.Done can read closeErr without holding the conn lock.
		c.cancel()
		// Detach the stream map under the lock, then fail the streams outside
		// it so per-stream callbacks don't run while holding c.m.
		c.m.Lock()
		streams := c.streams
		c.streams = nil
		c.m.Unlock()
		for _, str := range streams {
			str.closeWithError(errConnectionTimeout{})
		}
		c.pc.Close()
		c.scope.Done()
	})
	return nil
}

func (c *connection) IsClosed() bool {
Expand All @@ -170,29 +177,18 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error
return nil, c.closeErr
}

id := c.nextStreamID.Add(2) - 2
if id > math.MaxUint16 {
return nil, errors.New("exhausted stream ID space")
}
// Limit the number of streams, since we're not able to actually properly close them.
// See https://github.com/libp2p/specs/issues/575 for details.
if id > maxDataChannelID {
c.Close()
return c.OpenStream(ctx)
}

streamID := uint16(id)
dc, err := c.pc.CreateDataChannel("", &webrtc.DataChannelInit{ID: &streamID})
dc, err := c.pc.CreateDataChannel("", nil)
if err != nil {
return nil, err
}
rwc, err := c.detachChannel(ctx, dc)
if err != nil {
return nil, fmt.Errorf("open stream: %w", err)
}
str := newStream(dc, rwc, func() { c.removeStream(streamID) })
fmt.Println("opened dc with ID: ", *dc.ID())
str := newStream(dc, rwc, func() { c.removeStream(*dc.ID()) })
if err := c.addStream(str); err != nil {
str.Close()
str.Reset()
return nil, err
}
return str, nil
Expand All @@ -205,7 +201,7 @@ func (c *connection) AcceptStream() (network.MuxedStream, error) {
case dc := <-c.acceptQueue:
str := newStream(dc.channel, dc.stream, func() { c.removeStream(*dc.channel.ID()) })
if err := c.addStream(str); err != nil {
str.Close()
str.Reset()
return nil, err
}
return str, nil
Expand All @@ -223,6 +219,9 @@ func (c *connection) Transport() tpt.Transport { return c.transport }
func (c *connection) addStream(str *stream) error {
c.m.Lock()
defer c.m.Unlock()
if c.IsClosed() {
return fmt.Errorf("connection closed: %w", c.closeErr)
}
if _, ok := c.streams[str.id]; ok {
return errors.New("stream ID already exists")
}
Expand All @@ -238,20 +237,7 @@ func (c *connection) removeStream(id uint16) {

// onConnectionStateChange tears the connection down as timed out once the
// underlying peer connection reports Failed or Closed; other states are ignored.
func (c *connection) onConnectionStateChange(state webrtc.PeerConnectionState) {
	switch state {
	case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed:
		c.closeTimedOut()
	}
}

Expand Down
Loading

0 comments on commit e178585

Please sign in to comment.