Skip to content

Commit

Permalink
pass the connection scope to the connection (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann authored Jul 9, 2022
1 parent 3b13c83 commit c7149b3
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 41 deletions.
15 changes: 5 additions & 10 deletions p2p/transport/webtransport/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ type conn struct {
local, remote ma.Multiaddr
privKey ic.PrivKey
remotePubKey ic.PubKey
scope network.ConnScope
}

func newConn(tr tpt.Transport, wsess *webtransport.Session, privKey ic.PrivKey, remotePubKey ic.PubKey) (*conn, error) {
func newConn(tr tpt.Transport, wsess *webtransport.Session, privKey ic.PrivKey, remotePubKey ic.PubKey, scope network.ConnScope) (*conn, error) {
localPeer, err := peer.IDFromPrivateKey(privKey)
if err != nil {
return nil, err
Expand All @@ -49,6 +50,7 @@ func newConn(tr tpt.Transport, wsess *webtransport.Session, privKey ic.PrivKey,
remotePubKey: remotePubKey,
local: local,
remote: remote,
scope: scope,
}, nil
}

Expand Down Expand Up @@ -78,12 +80,5 @@ func (c *conn) RemotePeer() peer.ID { return c.remotePeer }
func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey }
func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.local }
func (c *conn) RemoteMultiaddr() ma.Multiaddr { return c.remote }

func (c *conn) Scope() network.ConnScope {
// TODO implement me
panic("implement me")
}

func (c *conn) Transport() tpt.Transport {
return c.transport
}
func (c *conn) Scope() network.ConnScope { return c.scope }
func (c *conn) Transport() tpt.Transport { return c.transport }
29 changes: 16 additions & 13 deletions p2p/transport/webtransport/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) {
}

// TODO: check ?type=multistream URL param
c, err := l.server.Upgrade(w, r)
sess, err := l.server.Upgrade(w, r)
if err != nil {
log.Debugw("upgrade failed", "error", err)
// TODO: think about the status code to use here
Expand All @@ -123,25 +123,32 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) {
return
}
ctx, cancel := context.WithTimeout(l.ctx, handshakeTimeout)
conn, err := l.handshake(ctx, c)
sconn, err := l.handshake(ctx, sess)
if err != nil {
cancel()
log.Debugw("handshake failed", "error", err)
c.Close()
sess.Close()
connScope.Done()
return
}
cancel()

if err := connScope.SetPeer(conn.RemotePeer()); err != nil {
log.Debugw("resource manager blocked incoming connection for peer", "peer", conn.RemotePeer(), "addr", r.RemoteAddr, "error", err)
conn.Close()
if err := connScope.SetPeer(sconn.RemotePeer()); err != nil {
log.Debugw("resource manager blocked incoming connection for peer", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err)
sess.Close()
connScope.Done()
return
}

c, err := newConn(l.transport, sess, sconn.LocalPrivateKey(), sconn.RemotePublicKey(), connScope)
if err != nil {
sess.Close()
connScope.Done()
return
}

// TODO: think about what happens when this channel fills up
l.queue <- conn
l.queue <- c
}

func (l *listener) Accept() (tpt.CapableConn, error) {
Expand All @@ -153,16 +160,12 @@ func (l *listener) Accept() (tpt.CapableConn, error) {
}
}

func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (tpt.CapableConn, error) {
func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (network.ConnSecurity, error) {
str, err := sess.AcceptStream(ctx)
if err != nil {
return nil, err
}
conn, err := l.noise.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "")
if err != nil {
return nil, err
}
return newConn(l.transport, sess, conn.LocalPrivateKey(), conn.RemotePublicKey())
return l.noise.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "")
}

func (l *listener) Addr() net.Addr {
Expand Down
45 changes: 29 additions & 16 deletions p2p/transport/webtransport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"fmt"
manet "github.com/multiformats/go-multiaddr/net"
"io"
"sync"
"time"
Expand All @@ -21,7 +22,6 @@ import (
"github.com/lucas-clemente/quic-go/http3"
"github.com/marten-seemann/webtransport-go"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multihash"
)

Expand Down Expand Up @@ -73,6 +73,15 @@ func New(key ic.PrivKey, rcmgr network.ResourceManager) (tpt.Transport, error) {
}

func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) {
_, addr, err := manet.DialArgs(raddr)
if err != nil {
return nil, err
}
certHashes, err := extractCertHashes(raddr)
if err != nil {
return nil, err
}

scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr)
if err != nil {
log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err)
Expand All @@ -84,32 +93,40 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp
return nil, err
}

conn, err := t.dial(ctx, raddr, p)
sess, err := t.dial(ctx, addr)
if err != nil {
scope.Done()
return nil, err
}
return conn, nil
}

func (t *transport) dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) {
_, addr, err := manet.DialArgs(raddr)
sconn, err := t.upgrade(ctx, sess, p, certHashes)
if err != nil {
sess.Close()
scope.Done()
return nil, err
}
url := fmt.Sprintf("https://%s%s", addr, webtransportHTTPEndpoint)
certHashes, err := extractCertHashes(raddr)
c, err := newConn(t, sess, t.privKey, sconn.RemotePublicKey(), scope)
if err != nil {
sess.Close()
scope.Done()
return nil, err
}
rsp, wconn, err := t.dialer.Dial(ctx, url, nil)
return c, nil
}

func (t *transport) dial(ctx context.Context, addr string) (*webtransport.Session, error) {
url := fmt.Sprintf("https://%s%s", addr, webtransportHTTPEndpoint)
rsp, sess, err := t.dialer.Dial(ctx, url, nil)
if err != nil {
return nil, err
}
if rsp.StatusCode < 200 || rsp.StatusCode > 299 {
return nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode)
}
str, err := wconn.OpenStreamSync(ctx)
return sess, err
}

func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (network.ConnSecurity, error) {
str, err := sess.OpenStreamSync(ctx)
if err != nil {
return nil, err
}
Expand All @@ -127,11 +144,7 @@ func (t *transport) dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp
if err != nil {
return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err)
}
sconn, err := t.noise.SecureOutboundWithEarlyData(ctx, &webtransportStream{Stream: str, wsess: wconn}, p, msgBytes)
if err != nil {
return nil, err
}
return newConn(t, wconn, t.privKey, sconn.RemotePublicKey())
return t.noise.SecureOutboundWithEarlyData(ctx, &webtransportStream{Stream: str, wsess: sess}, p, msgBytes)
}

func (t *transport) checkEarlyData(b []byte) error {
Expand Down
15 changes: 13 additions & 2 deletions p2p/transport/webtransport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"net"
"testing"
"time"

libp2pwebtransport "github.com/marten-seemann/go-libp2p-webtransport"

Expand Down Expand Up @@ -287,7 +288,17 @@ func TestResourceManagerListening(t *testing.T) {
// The handshake will complete, but the server will immediately close the connection.
conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
_, err = conn.AcceptStream()
require.Error(t, err)
defer conn.Close()
done := make(chan struct{})
go func() {
defer close(done)
_, err = conn.AcceptStream()
require.Error(t, err)
}()
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatal("timeout")
}
})
}

0 comments on commit c7149b3

Please sign in to comment.