From 7d5a5f83570faf2439a9db672e50d2150a5b0e39 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 22 Dec 2021 15:11:32 +0400 Subject: [PATCH] use the ResourceManager --- p2p/transport/tcp/tcp.go | 27 +++++++++++++-- p2p/transport/tcp/tcp_test.go | 64 ++++++++++++++++++++++++++++++++--- 2 files changed, 85 insertions(+), 6 deletions(-) diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index d21dcb0793..e9de3b345e 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -83,6 +83,10 @@ func (ll *tcpListener) Accept() (manet.Conn, error) { } tryLinger(c, ll.sec) tryKeepAlive(c, true) + // We're not calling OpenConnection in the resource manager here, + // since the manet.Conn doesn't allow us to save the scope. + // It's the caller's (usually the go-libp2p-transport-upgrader) responsibility + // to call the resource manager. return c, nil } @@ -94,6 +98,7 @@ func DisableReuseport() Option { return nil } } + func WithConnectionTimeout(d time.Duration) Option { return func(tr *TcpTransport) error { tr.connectTimeout = d @@ -113,6 +118,8 @@ type TcpTransport struct { // TCP connect timeout connectTimeout time.Duration + rcmgr network.ResourceManager + reuse rtpt.Transport } @@ -120,10 +127,14 @@ var _ transport.Transport = &TcpTransport{} // NewTCPTransport creates a tcp transport object that tracks dialers and listeners // created. It represents an entire TCP stack (though it might not necessarily be). -func NewTCPTransport(upgrader transport.Upgrader, opts ...Option) (*TcpTransport, error) { +func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*TcpTransport, error) { + if rcmgr == nil { + rcmgr = network.NullResourceManager + } tr := &TcpTransport{ Upgrader: upgrader, connectTimeout: defaultConnectTimeout, // can be set by using the WithConnectionTimeout option + rcmgr: rcmgr, } for _, o := range opts { if err := o(tr); err != nil { @@ -158,8 +169,19 @@ func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Co // Dial dials the peer at the remote address. func (t *TcpTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (transport.CapableConn, error) { + connScope, err := t.rcmgr.OpenConnection(network.DirOutbound, true) + if err != nil { + log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err) + return nil, err + } + if err := connScope.SetPeer(p); err != nil { + log.Debugw("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "error", err) + connScope.Done() + return nil, err + } conn, err := t.maDial(ctx, raddr) if err != nil { + connScope.Done() return nil, err } // Set linger to 0 so we never get stuck in the TIME-WAIT state. When @@ -169,13 +191,14 @@ func (t *TcpTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) tryKeepAlive(conn, true) c, err := newTracingConn(conn, true) if err != nil { + connScope.Done() return nil, err } direction := network.DirOutbound if ok, isClient, _ := network.GetSimultaneousConnect(ctx); ok && !isClient { direction = network.DirInbound } - return t.Upgrader.Upgrade(ctx, t, c, direction, p) + return t.Upgrader.Upgrade(ctx, t, c, direction, p, connScope) } // UseReuseport returns true if reuseport is enabled and available. diff --git a/p2p/transport/tcp/tcp_test.go b/p2p/transport/tcp/tcp_test.go index 21a12bd0fd..b650e9cbfa 100644 --- a/p2p/transport/tcp/tcp_test.go +++ b/p2p/transport/tcp/tcp_test.go @@ -1,8 +1,12 @@ package tcp import ( + "context" + "errors" "testing" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/sec" @@ -11,11 +15,13 @@ import ( csms "github.com/libp2p/go-conn-security-multistream" mplex "github.com/libp2p/go-libp2p-mplex" + mocknetwork "github.com/libp2p/go-libp2p-testing/mocks/network" ttransport "github.com/libp2p/go-libp2p-testing/suites/transport" tptu "github.com/libp2p/go-libp2p-transport-upgrader" ma "github.com/multiformats/go-multiaddr" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" ) @@ -26,11 +32,11 @@ func TestTcpTransport(t *testing.T) { ua, err := tptu.New(ia, new(mplex.Transport)) require.NoError(t, err) - ta, err := NewTCPTransport(ua) + ta, err := NewTCPTransport(ua, nil) require.NoError(t, err) ub, err := tptu.New(ib, new(mplex.Transport)) require.NoError(t, err) - tb, err := NewTCPTransport(ub) + tb, err := NewTCPTransport(ub, nil) require.NoError(t, err) zero := "/ip4/127.0.0.1/tcp/0" @@ -41,13 +47,63 @@ func TestTcpTransport(t *testing.T) { envReuseportVal = true } +func TestResourceManager(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + peerA, ia := makeInsecureMuxer(t) + _, ib := makeInsecureMuxer(t) + + ua, err := tptu.New(ia, new(mplex.Transport)) + require.NoError(t, err) + ta, err := NewTCPTransport(ua, nil) + require.NoError(t, err) + ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")) + require.NoError(t, err) + defer ln.Close() + + ub, err := tptu.New(ib, new(mplex.Transport)) + require.NoError(t, err) + rcmgr := mocknetwork.NewMockResourceManager(ctrl) + tb, err := NewTCPTransport(ub, rcmgr) + require.NoError(t, err) + + t.Run("success", func(t *testing.T) { + scope := mocknetwork.NewMockConnManagementScope(ctrl) + rcmgr.EXPECT().OpenConnection(network.DirOutbound, true).Return(scope, nil) + scope.EXPECT().SetPeer(peerA) + scope.EXPECT().PeerScope().Return(network.NullScope).AnyTimes() // called by the upgrader + conn, err := tb.Dial(context.Background(), ln.Multiaddr(), peerA) + require.NoError(t, err) + scope.EXPECT().Done() + defer conn.Close() + }) + + t.Run("connection denied", func(t *testing.T) { + rerr := errors.New("nope") + rcmgr.EXPECT().OpenConnection(network.DirOutbound, true).Return(nil, rerr) + _, err = tb.Dial(context.Background(), ln.Multiaddr(), peerA) + require.ErrorIs(t, err, rerr) + }) + + t.Run("peer denied", func(t *testing.T) { + scope := mocknetwork.NewMockConnManagementScope(ctrl) + rcmgr.EXPECT().OpenConnection(network.DirOutbound, true).Return(scope, nil) + rerr := errors.New("nope") + scope.EXPECT().SetPeer(peerA).Return(rerr) + scope.EXPECT().Done() + _, err = tb.Dial(context.Background(), ln.Multiaddr(), peerA) + require.ErrorIs(t, err, rerr) + }) +} + func TestTcpTransportCantDialDNS(t *testing.T) { for i := 0; i < 2; i++ { dnsa, err := ma.NewMultiaddr("/dns4/example.com/tcp/1234") require.NoError(t, err) var u transport.Upgrader - tpt, err := NewTCPTransport(u) + tpt, err := NewTCPTransport(u, nil) require.NoError(t, err) if tpt.CanDial(dnsa) { @@ -65,7 +121,7 @@ func TestTcpTransportCantListenUtp(t *testing.T) { require.NoError(t, err) var u transport.Upgrader - tpt, err := NewTCPTransport(u) + tpt, err := NewTCPTransport(u, nil) require.NoError(t, err) _, err = tpt.Listen(utpa)