From 19ebf83692bf3ab10e7f2cd0126a6a3a54d382d0 Mon Sep 17 00:00:00 2001 From: julian88110 <111450570+julian88110@users.noreply.github.com> Date: Fri, 16 Sep 2022 16:57:22 -0700 Subject: [PATCH 01/13] Muxer selection in TLS handshake first cut --- core/network/conn.go | 3 ++ core/sec/insecure/insecure.go | 10 ++++- core/sec/insecure/insecure_test.go | 4 +- core/sec/security.go | 8 ++-- p2p/muxer/muxer-multistream/multistream.go | 12 ++++++ p2p/net/conn-security-multistream/ssms.go | 10 ++--- .../conn-security-multistream/ssms_test.go | 18 ++++---- p2p/net/connmgr/connmgr_test.go | 1 + p2p/net/mock/mock_conn.go | 5 +++ p2p/net/swarm/swarm_conn.go | 5 +++ p2p/net/upgrader/listener_test.go | 8 ++-- p2p/net/upgrader/upgrader.go | 43 +++++++++++++++++-- p2p/security/noise/benchmark_test.go | 6 ++- p2p/security/noise/session.go | 7 +++ p2p/security/noise/session_test.go | 2 +- p2p/security/noise/session_transport.go | 4 +- p2p/security/noise/transport.go | 4 +- p2p/security/noise/transport_test.go | 38 ++++++++-------- p2p/security/tls/cmd/tlsdiag/client.go | 4 +- p2p/security/tls/cmd/tlsdiag/server.go | 4 +- p2p/security/tls/conn.go | 5 +++ p2p/security/tls/transport.go | 24 ++++++++++- p2p/security/tls/transport_test.go | 32 ++++++++------ p2p/transport/quic/conn.go | 6 +++ 24 files changed, 191 insertions(+), 72 deletions(-) diff --git a/core/network/conn.go b/core/network/conn.go index 8554493e25..85a5cdd3e3 100644 --- a/core/network/conn.go +++ b/core/network/conn.go @@ -48,6 +48,9 @@ type ConnSecurity interface { // RemotePublicKey returns the public key of the remote peer. RemotePublicKey() ic.PubKey + + // Early data negotiated by the security protocol. Empty if not supported. + EarlyData() string } // ConnMultiaddrs is an interface mixin for connection types that provide multiaddr diff --git a/core/sec/insecure/insecure.go b/core/sec/insecure/insecure.go index 12bd1842b4..6bc7a57673 100644 --- a/core/sec/insecure/insecure.go +++ b/core/sec/insecure/insecure.go @@ -60,7 +60,7 @@ func (t *Transport) LocalPrivateKey() ci.PrivKey { // // SecureInbound may fail if the remote peer sends an ID and public key that are inconsistent // with each other, or if a network error occurs during the ID exchange. -func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { +func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { conn := &Conn{ Conn: insecure, local: t.id, @@ -87,7 +87,7 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer // SecureOutbound may fail if the remote peer sends an ID and public key that are inconsistent // with each other, or if the ID sent by the remote peer does not match the one dialed. It may // also fail if a network error occurs during the ID exchange. -func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { +func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { conn := &Conn{ Conn: insecure, local: t.id, @@ -230,5 +230,11 @@ func (ic *Conn) LocalPrivateKey() ci.PrivKey { return ic.localPrivKey } +// EarlyData returns the security protocol's early data negotiated by handshake. +// Returns (empty string, false) if early data is not supported. +func (ic *Conn) EarlyData() string { + return "" +} + var _ sec.SecureTransport = (*Transport)(nil) var _ sec.SecureConn = (*Conn)(nil) diff --git a/core/sec/insecure/insecure_test.go b/core/sec/insecure/insecure_test.go index a3ce8314f4..da16772be0 100644 --- a/core/sec/insecure/insecure_test.go +++ b/core/sec/insecure/insecure_test.go @@ -94,9 +94,9 @@ func connect(t *testing.T, clientTpt, serverTpt *Transport, clientExpectsID, ser done := make(chan struct{}) go func() { defer close(done) - clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, clientExpectsID) + clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, clientExpectsID, nil) }() - serverConn, serverErr = serverTpt.SecureInbound(context.TODO(), server, serverExpectsID) + serverConn, serverErr = serverTpt.SecureInbound(context.TODO(), server, serverExpectsID, nil) <-done return } diff --git a/core/sec/security.go b/core/sec/security.go index c192a56a91..b76e62b900 100644 --- a/core/sec/security.go +++ b/core/sec/security.go @@ -20,10 +20,10 @@ type SecureConn interface { type SecureTransport interface { // SecureInbound secures an inbound connection. // If p is empty, connections from any peer are accepted. - SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error) + SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (SecureConn, error) // SecureOutbound secures an outbound connection. - SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error) + SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (SecureConn, error) } // A SecureMuxer is a wrapper around SecureTransport which can select security protocols @@ -33,10 +33,10 @@ type SecureMuxer interface { // The returned boolean indicates whether the connection should be treated as a server // connection; in the case of SecureInbound it should always be true. // If p is empty, connections from any peer are accepted. - SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, bool, error) + SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (SecureConn, bool, error) // SecureOutbound secures an outbound connection. // The returned boolean indicates whether the connection should be treated as a server // connection due to simultaneous open. - SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, bool, error) + SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (SecureConn, bool, error) } diff --git a/p2p/muxer/muxer-multistream/multistream.go b/p2p/muxer/muxer-multistream/multistream.go index bf9d41630c..208775842f 100644 --- a/p2p/muxer/muxer-multistream/multistream.go +++ b/p2p/muxer/muxer-multistream/multistream.go @@ -52,12 +52,14 @@ func (t *Transport) NewConn(nc net.Conn, isServer bool, scope network.PeerScope) return nil, err } proto = selected + fmt.Println(">>>>>> Server selected muxer: ", proto) } else { selected, err := mss.SelectOneOf(t.OrderPreference, nc) if err != nil { return nil, err } proto = selected + fmt.Println(">>>>>> Client selected muxer: ", proto) } if t.NegotiateTimeout != 0 { @@ -66,6 +68,7 @@ func (t *Transport) NewConn(nc net.Conn, isServer bool, scope network.PeerScope) } } + fmt.Println(">>>>>> multistream muxer conn selected proto: %s", proto) tpt, ok := t.tpts[proto] if !ok { return nil, fmt.Errorf("selected protocol we don't have a transport for") @@ -73,3 +76,12 @@ func (t *Transport) NewConn(nc net.Conn, isServer bool, scope network.PeerScope) return tpt.NewConn(nc, isServer, scope) } + +func (t *Transport) GetTranspotKeys() []string { + return t.OrderPreference +} + +func (t *Transport) GetTranspotByKey(key string) (network.Multiplexer, bool) { + val, ok := t.tpts[key] + return val, ok +} diff --git a/p2p/net/conn-security-multistream/ssms.go b/p2p/net/conn-security-multistream/ssms.go index 595d8dfde6..a5e3f07968 100644 --- a/p2p/net/conn-security-multistream/ssms.go +++ b/p2p/net/conn-security-multistream/ssms.go @@ -40,18 +40,18 @@ func (sm *SSMuxer) AddTransport(path string, transport sec.SecureTransport) { // SecureInbound secures an inbound connection using this multistream // multiplexed stream security transport. -func (sm *SSMuxer) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, bool, error) { +func (sm *SSMuxer) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, bool, error) { tpt, _, err := sm.selectProto(ctx, insecure, true) if err != nil { return nil, false, err } - sconn, err := tpt.SecureInbound(ctx, insecure, p) + sconn, err := tpt.SecureInbound(ctx, insecure, p, muxers) return sconn, true, err } // SecureOutbound secures an outbound connection using this multistream // multiplexed stream security transport. -func (sm *SSMuxer) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, bool, error) { +func (sm *SSMuxer) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, bool, error) { tpt, server, err := sm.selectProto(ctx, insecure, false) if err != nil { return nil, false, err @@ -59,7 +59,7 @@ func (sm *SSMuxer) SecureOutbound(ctx context.Context, insecure net.Conn, p peer var sconn sec.SecureConn if server { - sconn, err = tpt.SecureInbound(ctx, insecure, p) + sconn, err = tpt.SecureInbound(ctx, insecure, p, muxers) if err != nil { return nil, false, fmt.Errorf("failed to secure inbound connection: %s", err) } @@ -70,7 +70,7 @@ func (sm *SSMuxer) SecureOutbound(ctx context.Context, insecure net.Conn, p peer return nil, false, fmt.Errorf("unexpected peer") } } else { - sconn, err = tpt.SecureOutbound(ctx, insecure, p) + sconn, err = tpt.SecureOutbound(ctx, insecure, p, muxers) } return sconn, server, err diff --git a/p2p/net/conn-security-multistream/ssms_test.go b/p2p/net/conn-security-multistream/ssms_test.go index 5aa5db352d..d113397778 100644 --- a/p2p/net/conn-security-multistream/ssms_test.go +++ b/p2p/net/conn-security-multistream/ssms_test.go @@ -28,13 +28,13 @@ type TransportAdapter struct { mux *SSMuxer } -func (sm *TransportAdapter) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - sconn, _, err := sm.mux.SecureInbound(ctx, insecure, p) +func (sm *TransportAdapter) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { + sconn, _, err := sm.mux.SecureInbound(ctx, insecure, p, muxers) return sconn, err } -func (sm *TransportAdapter) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - sconn, _, err := sm.mux.SecureOutbound(ctx, insecure, p) +func (sm *TransportAdapter) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { + sconn, _, err := sm.mux.SecureOutbound(ctx, insecure, p, muxers) return sconn, err } @@ -57,7 +57,7 @@ func TestCommonProto(t *testing.T) { go func() { conn, err := ln.Accept() require.NoError(t, err) - c, err := muxB.SecureInbound(context.Background(), conn, idA) + c, err := muxB.SecureInbound(context.Background(), conn, idA, nil) require.NoError(t, err) connChan <- c }() @@ -67,7 +67,7 @@ func TestCommonProto(t *testing.T) { cconn, err := net.Dial("tcp", ln.Addr().String()) require.NoError(t, err) - cc, err := muxA.SecureOutbound(context.Background(), cconn, idB) + cc, err := muxA.SecureOutbound(context.Background(), cconn, idB, nil) require.NoError(t, err) require.Equal(t, cc.LocalPeer(), idA) require.Equal(t, cc.RemotePeer(), idB) @@ -103,7 +103,7 @@ func TestNoCommonProto(t *testing.T) { go func() { defer wg.Done() defer a.Close() - _, _, err := at.SecureInbound(ctx, a, "") + _, _, err := at.SecureInbound(ctx, a, "", nil) if err == nil { t.Error("connection should have failed") } @@ -112,10 +112,12 @@ func TestNoCommonProto(t *testing.T) { go func() { defer wg.Done() defer b.Close() - _, _, err := bt.SecureOutbound(ctx, b, "peerA") + _, _, err := bt.SecureOutbound(ctx, b, "peerA", nil) if err == nil { t.Error("connection should have failed") } }() wg.Wait() } + +// >>>>>> TODO <<<<<< Add test for non empty muxers cases diff --git a/p2p/net/connmgr/connmgr_test.go b/p2p/net/connmgr/connmgr_test.go index 312bdc1f3b..c61c46fd05 100644 --- a/p2p/net/connmgr/connmgr_test.go +++ b/p2p/net/connmgr/connmgr_test.go @@ -806,6 +806,7 @@ func (m mockConn) ID() string { panic func (m mockConn) NewStream(ctx context.Context) (network.Stream, error) { panic("implement me") } func (m mockConn) GetStreams() []network.Stream { panic("implement me") } func (m mockConn) Scope() network.ConnScope { panic("implement me") } +func (m mockConn) EarlyData() string { return "" } func TestPeerInfoSorting(t *testing.T) { t.Run("starts with temporary connections", func(t *testing.T) { diff --git a/p2p/net/mock/mock_conn.go b/p2p/net/mock/mock_conn.go index 5664fabb61..35b5464fd2 100644 --- a/p2p/net/mock/mock_conn.go +++ b/p2p/net/mock/mock_conn.go @@ -178,6 +178,11 @@ func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } +// EarlyData from security protocol handshake. Empty if not supported. +func (c *conn) EarlyData() string { + return "" +} + // Stat returns metadata about the connection func (c *conn) Stat() network.ConnStats { return c.stat diff --git a/p2p/net/swarm/swarm_conn.go b/p2p/net/swarm/swarm_conn.go index 779ee37374..a63f5a1a36 100644 --- a/p2p/net/swarm/swarm_conn.go +++ b/p2p/net/swarm/swarm_conn.go @@ -178,6 +178,11 @@ func (c *Conn) RemotePublicKey() ic.PubKey { return c.conn.RemotePublicKey() } +// EarlyData is the security protocol's early data result. Empty of not supported. +func (c *Conn) EarlyData() string { + return c.conn.EarlyData() +} + // Stat returns metadata pertaining to this connection func (c *Conn) Stat() network.ConnStats { c.streams.Lock() diff --git a/p2p/net/upgrader/listener_test.go b/p2p/net/upgrader/listener_test.go index 82c3952ef8..523f712ed8 100644 --- a/p2p/net/upgrader/listener_test.go +++ b/p2p/net/upgrader/listener_test.go @@ -30,13 +30,13 @@ type MuxAdapter struct { var _ sec.SecureMuxer = &MuxAdapter{} -func (mux *MuxAdapter) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, bool, error) { - sconn, err := mux.tpt.SecureInbound(ctx, insecure, p) +func (mux *MuxAdapter) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, bool, error) { + sconn, err := mux.tpt.SecureInbound(ctx, insecure, p, muxers) return sconn, true, err } -func (mux *MuxAdapter) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, bool, error) { - sconn, err := mux.tpt.SecureOutbound(ctx, insecure, p) +func (mux *MuxAdapter) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, bool, error) { + sconn, err := mux.tpt.SecureOutbound(ctx, insecure, p, muxers) return sconn, false, err } diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go index 58347865a5..58c904e50a 100644 --- a/p2p/net/upgrader/upgrader.go +++ b/p2p/net/upgrader/upgrader.go @@ -15,6 +15,7 @@ import ( "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/net/pnet" + msmux "github.com/libp2p/go-libp2p/p2p/muxer/muxer-multistream" manet "github.com/multiformats/go-multiaddr/net" ) @@ -78,6 +79,9 @@ type upgrader struct { var _ transport.Upgrader = &upgrader{} func New(secureMuxer sec.SecureMuxer, muxer network.Multiplexer, opts ...Option) (transport.Upgrader, error) { + + fmt.Printf(">>>>>> New upgrader with muxer type: %T\n", muxer) + u := &upgrader{ secure: secureMuxer, muxer: muxer, @@ -175,6 +179,7 @@ func (u *upgrader) upgrade(ctx context.Context, t transport.Transport, maconn ma sconn.Close() return nil, fmt.Errorf("failed to negotiate stream multiplexer: %s", err) } + fmt.Printf(">>>>>> upgrader got muxed connection from setupMuxer: %T\n", smconn) tc := &transportConn{ MuxedConn: smconn, @@ -188,14 +193,46 @@ func (u *upgrader) upgrade(ctx context.Context, t transport.Transport, maconn ma } func (u *upgrader) setupSecurity(ctx context.Context, conn net.Conn, p peer.ID, dir network.Direction) (sec.SecureConn, bool, error) { + // Add candidate muxers to the security layer handshake process to save + // muxer negotiation round trip if possible. + // TODO: explore if there is a way of extracting muxers other than type assertion. + muxers := []string{} + msmuxer, ok := u.muxer.(*msmux.Transport) + if ok { + muxers = msmuxer.GetTranspotKeys() + } + + // DEBUG + fmt.Println(">>>>>> Upgrader appending muxers to security proto: ", muxers) + if dir == network.DirInbound { - return u.secure.SecureInbound(ctx, conn, p) + return u.secure.SecureInbound(ctx, conn, p, muxers) } - return u.secure.SecureOutbound(ctx, conn, p) + return u.secure.SecureOutbound(ctx, conn, p, muxers) } -func (u *upgrader) setupMuxer(ctx context.Context, conn net.Conn, server bool, scope network.PeerScope) (network.MuxedConn, error) { +func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server bool, scope network.PeerScope) (network.MuxedConn, error) { // TODO: The muxer should take a context. + + //// + msmuxer, ok := u.muxer.(*msmux.Transport) + muxerSelected := conn.EarlyData() + + // Use muxer selected from security handshake if available. Otherwise fall back to multistream-selection. + fmt.Println(">>>>>> upgrader: muxer key from early data is: ", muxerSelected) + + if ok && len(muxerSelected) > 0 { + //if false && ok { + tpt, ok := msmuxer.GetTranspotByKey(muxerSelected) + if !ok { + return nil, fmt.Errorf("selected a muxer we don't have a transport for") + } + + fmt.Println(">>>>>> upgrader: muxerSetup Returning earlydata muxedConn") + return tpt.NewConn(conn, server, scope) + } + + //// done := make(chan struct{}) var smconn network.MuxedConn diff --git a/p2p/security/noise/benchmark_test.go b/p2p/security/noise/benchmark_test.go index 52454f5959..0778b8e623 100644 --- a/p2p/security/noise/benchmark_test.go +++ b/p2p/security/noise/benchmark_test.go @@ -81,10 +81,10 @@ func (b benchenv) connect(stopTimer bool) (*secureSession, *secureSession) { done := make(chan struct{}) go func() { defer close(done) - initSession, initErr = b.initTpt.SecureOutbound(context.TODO(), initConn, b.respTpt.localID) + initSession, initErr = b.initTpt.SecureOutbound(context.TODO(), initConn, b.respTpt.localID, nil) }() - respSession, respErr := b.respTpt.SecureInbound(context.TODO(), respConn, "") + respSession, respErr := b.respTpt.SecureInbound(context.TODO(), respConn, "", nil) <-done if initErr != nil { @@ -98,6 +98,8 @@ func (b benchenv) connect(stopTimer bool) (*secureSession, *secureSession) { return initSession.(*secureSession), respSession.(*secureSession) } +// >>>>>> TODO <<<<<< add test cases for non-null muxers. + func drain(r io.Reader, done chan<- error, writeTo io.Writer) { _, err := io.Copy(writeTo, r) done <- err diff --git a/p2p/security/noise/session.go b/p2p/security/noise/session.go index 5e3d0956cf..36230485bb 100644 --- a/p2p/security/noise/session.go +++ b/p2p/security/noise/session.go @@ -38,6 +38,9 @@ type secureSession struct { // noise prologue prologue []byte earlyDataHandler EarlyDataHandler + + // Early data derived from handshaking. It is empty if not supported. | ----------------------------------------------------------------------------------------------------------- + earlyData string } // newSecureSession creates a Noise session over the given insecureConn Conn, using @@ -106,6 +109,10 @@ func (s *secureSession) RemotePublicKey() crypto.PubKey { return s.remoteKey } +func (s *secureSession) EarlyData() string { + return s.earlyData +} + func (s *secureSession) SetDeadline(t time.Time) error { return s.insecureConn.SetDeadline(t) } diff --git a/p2p/security/noise/session_test.go b/p2p/security/noise/session_test.go index 85de01b2ba..f4b3cf31fc 100644 --- a/p2p/security/noise/session_test.go +++ b/p2p/security/noise/session_test.go @@ -20,7 +20,7 @@ func TestContextCancellationRespected(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err := initTransport.SecureOutbound(ctx, init, respTransport.localID) + _, err := initTransport.SecureOutbound(ctx, init, respTransport.localID, nil) require.Error(t, err) require.Equal(t, ctx.Err(), err) } diff --git a/p2p/security/noise/session_transport.go b/p2p/security/noise/session_transport.go index 973f6facf2..6f274c9c23 100644 --- a/p2p/security/noise/session_transport.go +++ b/p2p/security/noise/session_transport.go @@ -51,7 +51,7 @@ type SessionTransport struct { // SecureInbound runs the Noise handshake as the responder. // If p is empty, connections from any peer are accepted. -func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { +func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, i.earlyDataHandler, false) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) @@ -63,6 +63,6 @@ func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, } // SecureOutbound runs the Noise handshake as the initiator. -func (i *SessionTransport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { +func (i *SessionTransport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.earlyDataHandler, true) } diff --git a/p2p/security/noise/transport.go b/p2p/security/noise/transport.go index bd66d0fdd1..4550950891 100644 --- a/p2p/security/noise/transport.go +++ b/p2p/security/noise/transport.go @@ -40,7 +40,7 @@ func New(privkey crypto.PrivKey) (*Transport, error) { // SecureInbound runs the Noise handshake as the responder. // If p is empty, connections from any peer are accepted. -func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { +func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { c, err := newSecureSession(t, ctx, insecure, p, nil, nil, false) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) @@ -52,7 +52,7 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer } // SecureOutbound runs the Noise handshake as the initiator. -func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { +func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { return newSecureSession(t, ctx, insecure, p, nil, nil, true) } diff --git a/p2p/security/noise/transport_test.go b/p2p/security/noise/transport_test.go index 108efec816..26fcf088e9 100644 --- a/p2p/security/noise/transport_test.go +++ b/p2p/security/noise/transport_test.go @@ -78,10 +78,10 @@ func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSess done := make(chan struct{}) go func() { defer close(done) - initConn, initErr = initTransport.SecureOutbound(context.Background(), init, respTransport.localID) + initConn, initErr = initTransport.SecureOutbound(context.TODO(), init, respTransport.localID, nil) }() - respConn, respErr := respTransport.SecureInbound(context.Background(), resp, "") + respConn, respErr := respTransport.SecureInbound(context.TODO(), resp, "", nil) <-done if initErr != nil { @@ -106,7 +106,7 @@ func TestDeadlines(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - _, err := initTransport.SecureOutbound(ctx, init, respTransport.localID) + _, err := initTransport.SecureOutbound(ctx, init, respTransport.localID, nil) if err == nil { t.Fatalf("expected i/o timeout err; got: %s", err) } @@ -171,7 +171,7 @@ func TestPeerIDMatch(t *testing.T) { done := make(chan struct{}) go func() { defer close(done) - conn, err := initTransport.SecureOutbound(context.Background(), init, respTransport.localID) + conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID, nil) assert.NoError(t, err) assert.Equal(t, conn.RemotePeer(), respTransport.localID) b := make([]byte, 6) @@ -180,7 +180,7 @@ func TestPeerIDMatch(t *testing.T) { assert.Equal(t, b, []byte("foobar")) }() - conn, err := respTransport.SecureInbound(context.Background(), resp, initTransport.localID) + conn, err := respTransport.SecureInbound(context.TODO(), resp, initTransport.localID, nil) require.NoError(t, err) require.Equal(t, conn.RemotePeer(), initTransport.localID) _, err = conn.Write([]byte("foobar")) @@ -194,11 +194,11 @@ func TestPeerIDMismatchOutboundFailsHandshake(t *testing.T) { errChan := make(chan error) go func() { - _, err := initTransport.SecureOutbound(context.Background(), init, "a-random-peer-id") + _, err := initTransport.SecureOutbound(context.TODO(), init, "a-random-peer-id", nil) errChan <- err }() - _, err := respTransport.SecureInbound(context.Background(), resp, "") + _, err := respTransport.SecureInbound(context.TODO(), resp, "", nil) require.Error(t, err) initErr := <-errChan @@ -214,13 +214,13 @@ func TestPeerIDMismatchInboundFailsHandshake(t *testing.T) { done := make(chan struct{}) go func() { defer close(done) - conn, err := initTransport.SecureOutbound(context.Background(), init, respTransport.localID) + conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID, nil) assert.NoError(t, err) _, err = conn.Read([]byte{0}) assert.Error(t, err) }() - _, err := respTransport.SecureInbound(context.Background(), resp, "a-random-peer-id") + _, err := respTransport.SecureInbound(context.TODO(), resp, "a-random-peer-id", nil) require.Error(t, err, "expected responder to fail with peer ID mismatch error") <-done } @@ -387,7 +387,7 @@ func TestPrologueMatches(t *testing.T) { tpt, err := initTransport. WithSessionOptions(Prologue(commonPrologue)) require.NoError(t, err) - conn, err := tpt.SecureOutbound(context.Background(), initConn, respTransport.localID) + conn, err := tpt.SecureOutbound(context.TODO(), initConn, respTransport.localID, nil) require.NoError(t, err) defer conn.Close() }() @@ -395,7 +395,7 @@ func TestPrologueMatches(t *testing.T) { tpt, err := respTransport. WithSessionOptions(Prologue(commonPrologue)) require.NoError(t, err) - conn, err := tpt.SecureInbound(context.Background(), respConn, "") + conn, err := tpt.SecureInbound(context.TODO(), respConn, "", nil) require.NoError(t, err) defer conn.Close() <-done @@ -415,14 +415,14 @@ func TestPrologueDoesNotMatchFailsHandshake(t *testing.T) { tpt, err := initTransport. WithSessionOptions(Prologue(initPrologue)) require.NoError(t, err) - _, err = tpt.SecureOutbound(context.Background(), initConn, respTransport.localID) + _, err = tpt.SecureOutbound(context.TODO(), initConn, respTransport.localID, nil) require.Error(t, err) }() tpt, err := respTransport.WithSessionOptions(Prologue(respPrologue)) require.NoError(t, err) - _, err = tpt.SecureInbound(context.Background(), respConn, "") + _, err = tpt.SecureInbound(context.TODO(), respConn, "", nil) require.Error(t, err) <-done } @@ -467,11 +467,11 @@ func TestEarlyDataAccepted(t *testing.T) { errChan := make(chan error) go func() { - _, err := respTransport.SecureInbound(context.Background(), initConn, "") + _, err := respTransport.SecureInbound(context.Background(), initConn, "", nil) errChan <- err }() - conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) + conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID, nil) require.NoError(t, err) defer conn.Close() @@ -495,11 +495,11 @@ func TestEarlyDataRejected(t *testing.T) { errChan := make(chan error) go func() { - _, err := respTransport.SecureInbound(context.Background(), initConn, "") + _, err := respTransport.SecureInbound(context.Background(), initConn, "", nil) errChan <- err }() - _, err = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) + _, err = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID, nil) require.Error(t, err) select { @@ -522,11 +522,11 @@ func TestEarlyDataRejectedWithNoHandler(t *testing.T) { errChan := make(chan error) go func() { - _, err := respTransport.SecureInbound(context.Background(), initConn, "") + _, err := respTransport.SecureInbound(context.Background(), initConn, "", nil) errChan <- err }() - _, err = initTransport.SecureOutbound(context.Background(), respConn, respTransport.localID) + _, err = initTransport.SecureOutbound(context.Background(), respConn, respTransport.localID, nil) require.Error(t, err) select { diff --git a/p2p/security/tls/cmd/tlsdiag/client.go b/p2p/security/tls/cmd/tlsdiag/client.go index 7f1a7efecd..db7ef31175 100644 --- a/p2p/security/tls/cmd/tlsdiag/client.go +++ b/p2p/security/tls/cmd/tlsdiag/client.go @@ -49,7 +49,7 @@ func StartClient() error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - sconn, err := tp.SecureOutbound(ctx, conn, peerID) + sconn, err := tp.SecureOutbound(ctx, conn, peerID, nil) if err != nil { return err } @@ -61,3 +61,5 @@ func StartClient() error { fmt.Printf("Received message from server: %s\n", string(data)) return nil } + +// >>>>>> TODO <<<<<< Add an early data test case here. diff --git a/p2p/security/tls/cmd/tlsdiag/server.go b/p2p/security/tls/cmd/tlsdiag/server.go index 16f51f4a1b..6b575e4f88 100644 --- a/p2p/security/tls/cmd/tlsdiag/server.go +++ b/p2p/security/tls/cmd/tlsdiag/server.go @@ -57,7 +57,7 @@ func StartServer() error { func handleConn(tp *libp2ptls.Transport, conn net.Conn) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - sconn, err := tp.SecureInbound(ctx, conn, "") + sconn, err := tp.SecureInbound(ctx, conn, "", nil) if err != nil { return err } @@ -66,3 +66,5 @@ func handleConn(tp *libp2ptls.Transport, conn net.Conn) error { fmt.Printf("Closing connection to %s\n", conn.RemoteAddr()) return sconn.Close() } + +// >>>>>> TODO <<<<<< Add early data diagcase. diff --git a/p2p/security/tls/conn.go b/p2p/security/tls/conn.go index 6353eac80b..dd7d1e2a52 100644 --- a/p2p/security/tls/conn.go +++ b/p2p/security/tls/conn.go @@ -16,6 +16,7 @@ type conn struct { remotePeer peer.ID remotePubKey ci.PubKey + earlyData string } var _ sec.SecureConn = &conn{} @@ -35,3 +36,7 @@ func (c *conn) RemotePeer() peer.ID { func (c *conn) RemotePublicKey() ci.PubKey { return c.remotePubKey } + +func (c *conn) EarlyData() string { + return c.earlyData +} diff --git a/p2p/security/tls/transport.go b/p2p/security/tls/transport.go index f6aa64f6ab..8acd619dd2 100644 --- a/p2p/security/tls/transport.go +++ b/p2p/security/tls/transport.go @@ -51,8 +51,11 @@ var _ sec.SecureTransport = &Transport{} // SecureInbound runs the TLS handshake as a server. // If p is empty, connections from any peer are accepted. -func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { +func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { config, keyCh := t.identity.ConfigForPeer(p) + // >>>>>> the above returned config consists of the tls nextprotocol fields that can be amended here. + config.NextProtos = append(muxers, config.NextProtos...) + fmt.Println(">>>>>> this is where TLS server is created and handshake carried out <<<<<<") cs, err := t.handshake(ctx, tls.Server(insecure, config), keyCh) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) @@ -71,8 +74,10 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer // application data immediately afterwards. // If the handshake fails, the server will close the connection. The client will // notice this after 1 RTT when calling Read. -func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { +func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { config, keyCh := t.identity.ConfigForPeer(p) + // >>>>>> the above returned config consists of the tls nextprotocol fields that can be amended here. + config.NextProtos = append(muxers, config.NextProtos...) cs, err := t.handshake(ctx, tls.Client(insecure, config), keyCh) if err != nil { insecure.Close() @@ -89,9 +94,14 @@ func (t *Transport) handshake(ctx context.Context, tlsConn *tls.Conn, keyCh <-ch } }() + // handshaking... + fmt.Printf(">>> TLS handshaking <<< \n") if err := tlsConn.HandshakeContext(ctx); err != nil { return nil, err } + /// + ac := tlsConn.ConnectionState().NegotiatedProtocol + fmt.Printf(">>>> TLS negotiated app protocol is %s", ac) // Should be ready by this point, don't block. var remotePubKey ci.PubKey @@ -111,11 +121,21 @@ func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ci.PubKey) (sec.Se if err != nil { return nil, err } + + nextProto := tlsConn.ConnectionState().NegotiatedProtocol + if len(nextProto) > 0 && nextProto == "libp2p" { + nextProto = "" + } + + // here is where we can insert the NegotiatedProtocol data in te secureConn return value. + // fmt.Println(" >>>>>> Adopted next proto: ", nextProto) + return &conn{ Conn: tlsConn, localPeer: t.localPeer, privKey: t.privKey, remotePeer: remotePeerID, remotePubKey: remotePubKey, + earlyData: nextProto, }, nil } diff --git a/p2p/security/tls/transport_test.go b/p2p/security/tls/transport_test.go index ef0009ebd7..e7c0e63808 100644 --- a/p2p/security/tls/transport_test.go +++ b/p2p/security/tls/transport_test.go @@ -90,12 +90,12 @@ func TestHandshakeSucceeds(t *testing.T) { serverConnChan := make(chan sec.SecureConn) go func() { - serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "", nil) require.NoError(t, err) serverConnChan <- serverConn }() - clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID, nil) require.NoError(t, err) defer clientConn.Close() @@ -202,12 +202,12 @@ func TestHandshakeConnectionCancelations(t *testing.T) { errChan := make(chan error) go func() { - _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "", nil) errChan <- err }() ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err = clientTransport.SecureOutbound(ctx, clientInsecureConn, serverID) + _, err = clientTransport.SecureOutbound(ctx, clientInsecureConn, serverID, nil) require.ErrorIs(t, err, context.Canceled) require.Error(t, <-errChan) }) @@ -219,7 +219,7 @@ func TestHandshakeConnectionCancelations(t *testing.T) { go func() { ctx, cancel := context.WithCancel(context.Background()) cancel() - conn, err := serverTransport.SecureInbound(ctx, &delayedConn{Conn: serverInsecureConn, delay: 5 * time.Millisecond}, "") + conn, err := serverTransport.SecureInbound(ctx, &delayedConn{Conn: serverInsecureConn, delay: 5 * time.Millisecond}, "", nil) // crypto/tls' context handling works by spinning up a separate Go routine that watches the context, // and closes the underlying connection when that context is canceled. // It is therefore not guaranteed (but very likely) that this happens _during_ the TLS handshake. @@ -228,7 +228,7 @@ func TestHandshakeConnectionCancelations(t *testing.T) { } errChan <- err }() - _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID, nil) require.Error(t, err) require.ErrorIs(t, <-errChan, context.Canceled) }) @@ -248,7 +248,7 @@ func TestPeerIDMismatch(t *testing.T) { errChan := make(chan error) go func() { - conn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + conn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "", nil) // crypto/tls' context handling works by spinning up a separate Go routine that watches the context, // and closes the underlying connection when that context is canceled. // It is therefore not guaranteed (but very likely) that this happens _during_ the TLS handshake. @@ -260,7 +260,7 @@ func TestPeerIDMismatch(t *testing.T) { // dial, but expect the wrong peer ID thirdPartyID, _ := createPeer(t) - _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, thirdPartyID) + _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, thirdPartyID, nil) require.Error(t, err) require.Contains(t, err.Error(), "peer IDs don't match") @@ -281,11 +281,11 @@ func TestPeerIDMismatch(t *testing.T) { go func() { thirdPartyID, _ := createPeer(t) // expect the wrong peer ID - _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, thirdPartyID) + _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, thirdPartyID, nil) errChan <- err }() - conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID, nil) require.NoError(t, err) _, err = conn.Read([]byte{0}) require.Error(t, err) @@ -525,11 +525,11 @@ func TestInvalidCerts(t *testing.T) { serverErrChan := make(chan error) go func() { - _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "", nil) serverErrChan <- err }() - conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID, nil) require.NoError(t, err) clientErrChan := make(chan error) go func() { @@ -568,11 +568,11 @@ func TestInvalidCerts(t *testing.T) { errChan := make(chan error) go func() { - _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "", nil) errChan <- err }() - _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID, nil) require.Error(t, err) tr.checkErr(t, err) @@ -589,3 +589,7 @@ func TestInvalidCerts(t *testing.T) { }) } } + + + + // >>>>>> TODO <<<<<< Add more TLS test for early data case. diff --git a/p2p/transport/quic/conn.go b/p2p/transport/quic/conn.go index a012dea3d7..94aeec4c99 100644 --- a/p2p/transport/quic/conn.go +++ b/p2p/transport/quic/conn.go @@ -90,6 +90,12 @@ func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } +// EearlyData returns the early data negotiated from the security protocol +// handshake, empty if not supported. +func (c *conn) EarlyData() string { + return "" +} + // LocalMultiaddr returns the local Multiaddr associated func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.localMultiaddr From 37adb91e0af881e8ad9d4e1c8754a4a4681b3a36 Mon Sep 17 00:00:00 2001 From: julian88110 <111450570+julian88110@users.noreply.github.com> Date: Tue, 20 Sep 2022 15:23:58 -0700 Subject: [PATCH 02/13] Clean up some part of the code --- .../conn-security-multistream/ssms_test.go | 21 ++- p2p/net/upgrader/upgrader.go | 1 - p2p/security/tls/transport.go | 8 +- p2p/security/tls/transport_test.go | 126 ++++++++++-------- 4 files changed, 91 insertions(+), 65 deletions(-) diff --git a/p2p/net/conn-security-multistream/ssms_test.go b/p2p/net/conn-security-multistream/ssms_test.go index d113397778..469581d441 100644 --- a/p2p/net/conn-security-multistream/ssms_test.go +++ b/p2p/net/conn-security-multistream/ssms_test.go @@ -38,7 +38,13 @@ func (sm *TransportAdapter) SecureOutbound(ctx context.Context, insecure net.Con return sconn, err } -func TestCommonProto(t *testing.T) { +var clientMuxerList = [][]string{{}, {"muxer1/1.0.0", "muxer2/1.0.1"}, {"muxer1"}, {}} +var serverMuxerList = [][]string{{}, {"muxer2/1.0.1", "muxer1/1.0.0"}, {}, {"muxer1"}} +var insecureExpectedMuxers = []string{"", "", "", ""} + +const numMuxers = 4 + +func commonProto(t *testing.T, serverMuxers, clientMuxers []string, expectedMuxer string) { privA, idA := newPeer(t) privB, idB := newPeer(t) @@ -57,7 +63,7 @@ func TestCommonProto(t *testing.T) { go func() { conn, err := ln.Accept() require.NoError(t, err) - c, err := muxB.SecureInbound(context.Background(), conn, idA, nil) + c, err := muxB.SecureInbound(context.Background(), conn, idA, serverMuxers) require.NoError(t, err) connChan <- c }() @@ -67,7 +73,7 @@ func TestCommonProto(t *testing.T) { cconn, err := net.Dial("tcp", ln.Addr().String()) require.NoError(t, err) - cc, err := muxA.SecureOutbound(context.Background(), cconn, idB, nil) + cc, err := muxA.SecureOutbound(context.Background(), cconn, idB, clientMuxers) require.NoError(t, err) require.Equal(t, cc.LocalPeer(), idA) require.Equal(t, cc.RemotePeer(), idB) @@ -81,6 +87,13 @@ func TestCommonProto(t *testing.T) { b, err := io.ReadAll(sc) require.NoError(t, err) require.Equal(t, "foobar", string(b)) + require.Equal(t, expectedMuxer, cc.EarlyData()) +} + +func TestCommonProto(t *testing.T) { + for i := 0; i < numMuxers; i++ { + commonProto(t, serverMuxerList[i], clientMuxerList[i], insecureExpectedMuxers[i]) + } } func TestNoCommonProto(t *testing.T) { @@ -119,5 +132,3 @@ func TestNoCommonProto(t *testing.T) { }() wg.Wait() } - -// >>>>>> TODO <<<<<< Add test for non empty muxers cases diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go index 58c904e50a..d8b5f1420e 100644 --- a/p2p/net/upgrader/upgrader.go +++ b/p2p/net/upgrader/upgrader.go @@ -228,7 +228,6 @@ func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server b return nil, fmt.Errorf("selected a muxer we don't have a transport for") } - fmt.Println(">>>>>> upgrader: muxerSetup Returning earlydata muxedConn") return tpt.NewConn(conn, server, scope) } diff --git a/p2p/security/tls/transport.go b/p2p/security/tls/transport.go index 8acd619dd2..c3b65661c9 100644 --- a/p2p/security/tls/transport.go +++ b/p2p/security/tls/transport.go @@ -53,9 +53,8 @@ var _ sec.SecureTransport = &Transport{} // If p is empty, connections from any peer are accepted. func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { config, keyCh := t.identity.ConfigForPeer(p) - // >>>>>> the above returned config consists of the tls nextprotocol fields that can be amended here. + // Prepend the prefered muxers list to TLS config. config.NextProtos = append(muxers, config.NextProtos...) - fmt.Println(">>>>>> this is where TLS server is created and handshake carried out <<<<<<") cs, err := t.handshake(ctx, tls.Server(insecure, config), keyCh) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) @@ -76,7 +75,7 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer // notice this after 1 RTT when calling Read. func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { config, keyCh := t.identity.ConfigForPeer(p) - // >>>>>> the above returned config consists of the tls nextprotocol fields that can be amended here. + // Prepend the prefered muxers list to TLS config. config.NextProtos = append(muxers, config.NextProtos...) cs, err := t.handshake(ctx, tls.Client(insecure, config), keyCh) if err != nil { @@ -95,13 +94,12 @@ func (t *Transport) handshake(ctx context.Context, tlsConn *tls.Conn, keyCh <-ch }() // handshaking... - fmt.Printf(">>> TLS handshaking <<< \n") if err := tlsConn.HandshakeContext(ctx); err != nil { return nil, err } /// ac := tlsConn.ConnectionState().NegotiatedProtocol - fmt.Printf(">>>> TLS negotiated app protocol is %s", ac) + fmt.Println(">>>>>> TLS negotiated app protocol is: ", ac) // Should be ready by this point, don't block. var remotePubKey ci.PubKey diff --git a/p2p/security/tls/transport_test.go b/p2p/security/tls/transport_test.go index e7c0e63808..ed513e5838 100644 --- a/p2p/security/tls/transport_test.go +++ b/p2p/security/tls/transport_test.go @@ -28,6 +28,12 @@ import ( "github.com/stretchr/testify/require" ) +var clientMuxerList = [][]string{{}, {"muxer1/1.0.0", "muxer2/1.0.1"}, {"muxer1"}, {}} +var serverMuxerList = [][]string{{}, {"muxer2/1.0.1", "muxer1/1.0.0"}, {}, {"muxer1"}} +var expectedMuxers = []string{"", "muxer2/1.0.1", "", ""} + +const numMuxers = 4 + func createPeer(t *testing.T) (peer.ID, ic.PrivKey) { var priv ic.PrivKey var err error @@ -84,18 +90,21 @@ func isWindowsTCPCloseError(err error) bool { func TestHandshakeSucceeds(t *testing.T) { clientID, clientKey := createPeer(t) serverID, serverKey := createPeer(t) + var expectedMuxer string + var clientMuxers []string + var serverMuxers []string handshake := func(t *testing.T, clientTransport *Transport, serverTransport *Transport) { clientInsecureConn, serverInsecureConn := connect(t) serverConnChan := make(chan sec.SecureConn) go func() { - serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "", nil) + serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "", serverMuxers) require.NoError(t, err) serverConnChan <- serverConn }() - clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID, nil) + clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID, clientMuxers) require.NoError(t, err) defer clientConn.Close() @@ -115,6 +124,7 @@ func TestHandshakeSucceeds(t *testing.T) { require.Equal(t, serverConn.RemotePeer(), clientID) require.True(t, clientConn.RemotePublicKey().Equals(serverKey.GetPublic()), "server public key mismatch") require.True(t, serverConn.RemotePublicKey().Equals(clientKey.GetPublic()), "client public key mismatch") + require.Equal(t, clientConn.EarlyData(), expectedMuxer) // exchange some data _, err = serverConn.Write([]byte("foobar")) require.NoError(t, err) @@ -130,16 +140,21 @@ func TestHandshakeSucceeds(t *testing.T) { serverTransport, err := New(serverKey) require.NoError(t, err) - t.Run("standard TLS with extension not critical", func(t *testing.T) { - handshake(t, clientTransport, serverTransport) - }) + for i := 0; i < numMuxers; i++ { + expectedMuxer = expectedMuxers[i] + clientMuxers = clientMuxerList[i] + serverMuxers = serverMuxerList[i] + t.Run("standard TLS with extension not critical", func(t *testing.T) { + handshake(t, clientTransport, serverTransport) + }) - t.Run("standard TLS with extension critical", func(t *testing.T) { - extensionCritical = true - t.Cleanup(func() { extensionCritical = false }) + t.Run("standard TLS with extension critical", func(t *testing.T) { + extensionCritical = true + t.Cleanup(func() { extensionCritical = false }) - handshake(t, clientTransport, serverTransport) - }) + handshake(t, clientTransport, serverTransport) + }) + } // Use transports with custom TLS certificates @@ -163,16 +178,21 @@ func TestHandshakeSucceeds(t *testing.T) { serverTransport.identity, err = NewIdentity(serverKey, WithCertTemplate(serverCertTmpl)) require.NoError(t, err) - t.Run("custom TLS with extension not critical", func(t *testing.T) { - handshake(t, clientTransport, serverTransport) - }) + for i := 0; i < numMuxers; i++ { + expectedMuxer = expectedMuxers[i] + clientMuxers = clientMuxerList[i] + serverMuxers = serverMuxerList[i] + t.Run("custom TLS with extension not critical", func(t *testing.T) { + handshake(t, clientTransport, serverTransport) + }) - t.Run("custom TLS with extension critical", func(t *testing.T) { - extensionCritical = true - t.Cleanup(func() { extensionCritical = false }) + t.Run("custom TLS with extension critical", func(t *testing.T) { + extensionCritical = true + t.Cleanup(func() { extensionCritical = false }) - handshake(t, clientTransport, serverTransport) - }) + handshake(t, clientTransport, serverTransport) + }) + } } // crypto/tls' cancellation logic works by spinning up a separate Go routine that watches the ctx. @@ -197,41 +217,43 @@ func TestHandshakeConnectionCancelations(t *testing.T) { serverTransport, err := New(serverKey) require.NoError(t, err) - t.Run("cancel outgoing connection", func(t *testing.T) { - clientInsecureConn, serverInsecureConn := connect(t) - - errChan := make(chan error) - go func() { - _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "", nil) - errChan <- err - }() - ctx, cancel := context.WithCancel(context.Background()) - cancel() - _, err = clientTransport.SecureOutbound(ctx, clientInsecureConn, serverID, nil) - require.ErrorIs(t, err, context.Canceled) - require.Error(t, <-errChan) - }) - - t.Run("cancel incoming connection", func(t *testing.T) { - clientInsecureConn, serverInsecureConn := connect(t) + for i := 0; i < numMuxers; i++ { + t.Run("cancel outgoing connection", func(t *testing.T) { + clientInsecureConn, serverInsecureConn := connect(t) - errChan := make(chan error) - go func() { + errChan := make(chan error) + go func() { + _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "", serverMuxerList[i]) + errChan <- err + }() ctx, cancel := context.WithCancel(context.Background()) cancel() - conn, err := serverTransport.SecureInbound(ctx, &delayedConn{Conn: serverInsecureConn, delay: 5 * time.Millisecond}, "", nil) - // crypto/tls' context handling works by spinning up a separate Go routine that watches the context, - // and closes the underlying connection when that context is canceled. - // It is therefore not guaranteed (but very likely) that this happens _during_ the TLS handshake. - if err == nil { - _, err = conn.Read([]byte{0}) - } - errChan <- err - }() - _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID, nil) - require.Error(t, err) - require.ErrorIs(t, <-errChan, context.Canceled) - }) + _, err = clientTransport.SecureOutbound(ctx, clientInsecureConn, serverID, clientMuxerList[i]) + require.ErrorIs(t, err, context.Canceled) + require.Error(t, <-errChan) + }) + + t.Run("cancel incoming connection", func(t *testing.T) { + clientInsecureConn, serverInsecureConn := connect(t) + + errChan := make(chan error) + go func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + conn, err := serverTransport.SecureInbound(ctx, &delayedConn{Conn: serverInsecureConn, delay: 5 * time.Millisecond}, "", serverMuxerList[i]) + // crypto/tls' context handling works by spinning up a separate Go routine that watches the context, + // and closes the underlying connection when that context is canceled. + // It is therefore not guaranteed (but very likely) that this happens _during_ the TLS handshake. + if err == nil { + _, err = conn.Read([]byte{0}) + } + errChan <- err + }() + _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID, clientMuxerList[i]) + require.Error(t, err) + require.ErrorIs(t, <-errChan, context.Canceled) + }) + } } func TestPeerIDMismatch(t *testing.T) { @@ -589,7 +611,3 @@ func TestInvalidCerts(t *testing.T) { }) } } - - - - // >>>>>> TODO <<<<<< Add more TLS test for early data case. From 3d677015b21cf38ad5a0b8af148378cafb2ef383 Mon Sep 17 00:00:00 2001 From: julian88110 <111450570+julian88110@users.noreply.github.com> Date: Tue, 20 Sep 2022 19:00:57 -0700 Subject: [PATCH 03/13] Change earlydata to ConnectionState for security connection. --- core/network/conn.go | 9 ++++++++- core/sec/insecure/insecure.go | 8 ++++---- p2p/net/conn-security-multistream/ssms_test.go | 2 +- p2p/net/connmgr/connmgr_test.go | 2 +- p2p/net/mock/mock_conn.go | 6 +++--- p2p/net/swarm/swarm_conn.go | 7 ++++--- p2p/net/upgrader/upgrader.go | 2 +- p2p/security/noise/session.go | 9 +++++---- p2p/security/tls/conn.go | 11 ++++++----- p2p/security/tls/transport.go | 16 +++++++--------- p2p/security/tls/transport_test.go | 2 +- p2p/transport/quic/conn.go | 8 ++++---- 12 files changed, 45 insertions(+), 37 deletions(-) diff --git a/core/network/conn.go b/core/network/conn.go index 85a5cdd3e3..66e5ce7538 100644 --- a/core/network/conn.go +++ b/core/network/conn.go @@ -34,6 +34,13 @@ type Conn interface { GetStreams() []Stream } +// ConnectionState holds extra information releated to the ConnSecurity entity. +type ConnectionState struct { + // Early data result derived from security protocol handshake. + // For example, Noise handshake payload or TLS/ALPN negotiation. + EarlyData string +} + // ConnSecurity is the interface that one can mix into a connection interface to // give it the security methods. type ConnSecurity interface { @@ -50,7 +57,7 @@ type ConnSecurity interface { RemotePublicKey() ic.PubKey // Early data negotiated by the security protocol. Empty if not supported. - EarlyData() string + ConnState() ConnectionState } // ConnMultiaddrs is an interface mixin for connection types that provide multiaddr diff --git a/core/sec/insecure/insecure.go b/core/sec/insecure/insecure.go index 6bc7a57673..4888258439 100644 --- a/core/sec/insecure/insecure.go +++ b/core/sec/insecure/insecure.go @@ -10,6 +10,7 @@ import ( "net" ci "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/sec" pb "github.com/libp2p/go-libp2p/core/sec/insecure/pb" @@ -230,10 +231,9 @@ func (ic *Conn) LocalPrivateKey() ci.PrivKey { return ic.localPrivKey } -// EarlyData returns the security protocol's early data negotiated by handshake. -// Returns (empty string, false) if early data is not supported. -func (ic *Conn) EarlyData() string { - return "" +// ConnState returns the security connection's state information. +func (ic *Conn) ConnState() network.ConnectionState { + return network.ConnectionState{} } var _ sec.SecureTransport = (*Transport)(nil) diff --git a/p2p/net/conn-security-multistream/ssms_test.go b/p2p/net/conn-security-multistream/ssms_test.go index 469581d441..67ea53939d 100644 --- a/p2p/net/conn-security-multistream/ssms_test.go +++ b/p2p/net/conn-security-multistream/ssms_test.go @@ -87,7 +87,7 @@ func commonProto(t *testing.T, serverMuxers, clientMuxers []string, expectedMuxe b, err := io.ReadAll(sc) require.NoError(t, err) require.Equal(t, "foobar", string(b)) - require.Equal(t, expectedMuxer, cc.EarlyData()) + require.Equal(t, expectedMuxer, cc.ConnState().EarlyData) } func TestCommonProto(t *testing.T) { diff --git a/p2p/net/connmgr/connmgr_test.go b/p2p/net/connmgr/connmgr_test.go index c61c46fd05..2053e3e6f7 100644 --- a/p2p/net/connmgr/connmgr_test.go +++ b/p2p/net/connmgr/connmgr_test.go @@ -806,7 +806,7 @@ func (m mockConn) ID() string { panic func (m mockConn) NewStream(ctx context.Context) (network.Stream, error) { panic("implement me") } func (m mockConn) GetStreams() []network.Stream { panic("implement me") } func (m mockConn) Scope() network.ConnScope { panic("implement me") } -func (m mockConn) EarlyData() string { return "" } +func (m mockConn) ConnState() network.ConnectionState { return network.ConnectionState{} } func TestPeerInfoSorting(t *testing.T) { t.Run("starts with temporary connections", func(t *testing.T) { diff --git a/p2p/net/mock/mock_conn.go b/p2p/net/mock/mock_conn.go index 35b5464fd2..48015a4c61 100644 --- a/p2p/net/mock/mock_conn.go +++ b/p2p/net/mock/mock_conn.go @@ -178,9 +178,9 @@ func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } -// EarlyData from security protocol handshake. Empty if not supported. -func (c *conn) EarlyData() string { - return "" +// ConnState of security connection. Empty if not supported. +func (c *conn) ConnState() network.ConnectionState { + return network.ConnectionState{} } // Stat returns metadata about the connection diff --git a/p2p/net/swarm/swarm_conn.go b/p2p/net/swarm/swarm_conn.go index a63f5a1a36..4de2727f80 100644 --- a/p2p/net/swarm/swarm_conn.go +++ b/p2p/net/swarm/swarm_conn.go @@ -178,9 +178,10 @@ func (c *Conn) RemotePublicKey() ic.PubKey { return c.conn.RemotePublicKey() } -// EarlyData is the security protocol's early data result. Empty of not supported. -func (c *Conn) EarlyData() string { - return c.conn.EarlyData() +// ConnState is the security connection state. including early data result. +// Empty if not supported. +func (c *Conn) ConnState() network.ConnectionState { + return c.conn.ConnState() } // Stat returns metadata pertaining to this connection diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go index d8b5f1420e..2c7a1d44c6 100644 --- a/p2p/net/upgrader/upgrader.go +++ b/p2p/net/upgrader/upgrader.go @@ -216,7 +216,7 @@ func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server b //// msmuxer, ok := u.muxer.(*msmux.Transport) - muxerSelected := conn.EarlyData() + muxerSelected := conn.ConnState().EarlyData // Use muxer selected from security handshake if available. Otherwise fall back to multistream-selection. fmt.Println(">>>>>> upgrader: muxer key from early data is: ", muxerSelected) diff --git a/p2p/security/noise/session.go b/p2p/security/noise/session.go index 36230485bb..a2ee724473 100644 --- a/p2p/security/noise/session.go +++ b/p2p/security/noise/session.go @@ -10,6 +10,7 @@ import ( "github.com/flynn/noise" "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" ) @@ -39,8 +40,8 @@ type secureSession struct { prologue []byte earlyDataHandler EarlyDataHandler - // Early data derived from handshaking. It is empty if not supported. | ----------------------------------------------------------------------------------------------------------- - earlyData string + // Early data derived from handshaking. It is empty if not supported. | ----------------------------------------------------------------------------------------------------------- + earlyData string } // newSecureSession creates a Noise session over the given insecureConn Conn, using @@ -109,8 +110,8 @@ func (s *secureSession) RemotePublicKey() crypto.PubKey { return s.remoteKey } -func (s *secureSession) EarlyData() string { - return s.earlyData +func (s *secureSession) ConnState() network.ConnectionState { + return network.ConnectionState{EarlyData: s.earlyData} } func (s *secureSession) SetDeadline(t time.Time) error { diff --git a/p2p/security/tls/conn.go b/p2p/security/tls/conn.go index dd7d1e2a52..3ebc7aefc5 100644 --- a/p2p/security/tls/conn.go +++ b/p2p/security/tls/conn.go @@ -4,6 +4,7 @@ import ( "crypto/tls" ci "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/sec" ) @@ -14,9 +15,9 @@ type conn struct { localPeer peer.ID privKey ci.PrivKey - remotePeer peer.ID - remotePubKey ci.PubKey - earlyData string + remotePeer peer.ID + remotePubKey ci.PubKey + connectionState network.ConnectionState } var _ sec.SecureConn = &conn{} @@ -37,6 +38,6 @@ func (c *conn) RemotePublicKey() ci.PubKey { return c.remotePubKey } -func (c *conn) EarlyData() string { - return c.earlyData +func (c *conn) ConnState() network.ConnectionState { + return c.connectionState } diff --git a/p2p/security/tls/transport.go b/p2p/security/tls/transport.go index c3b65661c9..5f59dba068 100644 --- a/p2p/security/tls/transport.go +++ b/p2p/security/tls/transport.go @@ -11,6 +11,7 @@ import ( "github.com/libp2p/go-libp2p/core/canonicallog" ci "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/sec" @@ -125,15 +126,12 @@ func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ci.PubKey) (sec.Se nextProto = "" } - // here is where we can insert the NegotiatedProtocol data in te secureConn return value. - // fmt.Println(" >>>>>> Adopted next proto: ", nextProto) - return &conn{ - Conn: tlsConn, - localPeer: t.localPeer, - privKey: t.privKey, - remotePeer: remotePeerID, - remotePubKey: remotePubKey, - earlyData: nextProto, + Conn: tlsConn, + localPeer: t.localPeer, + privKey: t.privKey, + remotePeer: remotePeerID, + remotePubKey: remotePubKey, + connectionState: network.ConnectionState{EarlyData: nextProto}, }, nil } diff --git a/p2p/security/tls/transport_test.go b/p2p/security/tls/transport_test.go index ed513e5838..8a12dc2bef 100644 --- a/p2p/security/tls/transport_test.go +++ b/p2p/security/tls/transport_test.go @@ -124,7 +124,7 @@ func TestHandshakeSucceeds(t *testing.T) { require.Equal(t, serverConn.RemotePeer(), clientID) require.True(t, clientConn.RemotePublicKey().Equals(serverKey.GetPublic()), "server public key mismatch") require.True(t, serverConn.RemotePublicKey().Equals(clientKey.GetPublic()), "client public key mismatch") - require.Equal(t, clientConn.EarlyData(), expectedMuxer) + require.Equal(t, clientConn.ConnState().EarlyData, expectedMuxer) // exchange some data _, err = serverConn.Write([]byte("foobar")) require.NoError(t, err) diff --git a/p2p/transport/quic/conn.go b/p2p/transport/quic/conn.go index 94aeec4c99..1af53f2abc 100644 --- a/p2p/transport/quic/conn.go +++ b/p2p/transport/quic/conn.go @@ -90,10 +90,10 @@ func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } -// EearlyData returns the early data negotiated from the security protocol -// handshake, empty if not supported. -func (c *conn) EarlyData() string { - return "" +// ConnState is the state of security connection. +// It is empty if not supported. +func (c *conn) ConnState() network.ConnectionState { + return network.ConnectionState{} } // LocalMultiaddr returns the local Multiaddr associated From 9a37304024063889e8706489f9d51103c44e905e Mon Sep 17 00:00:00 2001 From: julian88110 <111450570+julian88110@users.noreply.github.com> Date: Wed, 21 Sep 2022 10:44:08 -0700 Subject: [PATCH 04/13] resolve merging conflicts --- p2p/security/noise/session.go | 2 +- p2p/security/noise/transport_test.go | 189 +++++++++++++++--------- p2p/transport/webtransport/listener.go | 2 +- p2p/transport/webtransport/transport.go | 2 +- 4 files changed, 122 insertions(+), 73 deletions(-) diff --git a/p2p/security/noise/session.go b/p2p/security/noise/session.go index a2ee724473..83fa6a729b 100644 --- a/p2p/security/noise/session.go +++ b/p2p/security/noise/session.go @@ -40,7 +40,7 @@ type secureSession struct { prologue []byte earlyDataHandler EarlyDataHandler - // Early data derived from handshaking. It is empty if not supported. | ----------------------------------------------------------------------------------------------------------- + // Early data derived from handshaking. It is empty if not supported. earlyData string } diff --git a/p2p/security/noise/transport_test.go b/p2p/security/noise/transport_test.go index 550eb10464..ab69cc3d28 100644 --- a/p2p/security/noise/transport_test.go +++ b/p2p/security/noise/transport_test.go @@ -16,6 +16,7 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/sec" + "github.com/libp2p/go-libp2p/p2p/security/noise/pb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -78,10 +79,10 @@ func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSess done := make(chan struct{}) go func() { defer close(done) - initConn, initErr = initTransport.SecureOutbound(context.TODO(), init, respTransport.localID, nil) + initConn, initErr = initTransport.SecureOutbound(context.Background(), init, respTransport.localID, nil) }() - respConn, respErr := respTransport.SecureInbound(context.TODO(), resp, "", nil) + respConn, respErr := respTransport.SecureInbound(context.Background(), resp, "", nil) <-done if initErr != nil { @@ -171,7 +172,7 @@ func TestPeerIDMatch(t *testing.T) { done := make(chan struct{}) go func() { defer close(done) - conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID, nil) + conn, err := initTransport.SecureOutbound(context.Background(), init, respTransport.localID, nil) assert.NoError(t, err) assert.Equal(t, conn.RemotePeer(), respTransport.localID) b := make([]byte, 6) @@ -180,7 +181,7 @@ func TestPeerIDMatch(t *testing.T) { assert.Equal(t, b, []byte("foobar")) }() - conn, err := respTransport.SecureInbound(context.TODO(), resp, initTransport.localID, nil) + conn, err := respTransport.SecureInbound(context.Background(), resp, initTransport.localID, nil) require.NoError(t, err) require.Equal(t, conn.RemotePeer(), initTransport.localID) _, err = conn.Write([]byte("foobar")) @@ -194,11 +195,11 @@ func TestPeerIDMismatchOutboundFailsHandshake(t *testing.T) { errChan := make(chan error) go func() { - _, err := initTransport.SecureOutbound(context.TODO(), init, "a-random-peer-id", nil) + _, err := initTransport.SecureOutbound(context.Background(), init, "a-random-peer-id", nil) errChan <- err }() - _, err := respTransport.SecureInbound(context.TODO(), resp, "", nil) + _, err := respTransport.SecureInbound(context.Background(), resp, "") require.Error(t, err) initErr := <-errChan @@ -214,13 +215,13 @@ func TestPeerIDMismatchInboundFailsHandshake(t *testing.T) { done := make(chan struct{}) go func() { defer close(done) - conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID, nil) + conn, err := initTransport.SecureOutbound(context.Background(), init, respTransport.localID) assert.NoError(t, err) _, err = conn.Read([]byte{0}) assert.Error(t, err) }() - _, err := respTransport.SecureInbound(context.TODO(), resp, "a-random-peer-id", nil) + _, err := respTransport.SecureInbound(context.Background(), resp, "a-random-peer-id") require.Error(t, err, "expected responder to fail with peer ID mismatch error") <-done } @@ -387,7 +388,7 @@ func TestPrologueMatches(t *testing.T) { tpt, err := initTransport. WithSessionOptions(Prologue(commonPrologue)) require.NoError(t, err) - conn, err := tpt.SecureOutbound(context.TODO(), initConn, respTransport.localID, nil) + conn, err := tpt.SecureOutbound(context.Background(), initConn, respTransport.localID) require.NoError(t, err) defer conn.Close() }() @@ -395,7 +396,7 @@ func TestPrologueMatches(t *testing.T) { tpt, err := respTransport. WithSessionOptions(Prologue(commonPrologue)) require.NoError(t, err) - conn, err := tpt.SecureInbound(context.TODO(), respConn, "", nil) + conn, err := tpt.SecureInbound(context.Background(), respConn, "") require.NoError(t, err) defer conn.Close() <-done @@ -415,106 +416,154 @@ func TestPrologueDoesNotMatchFailsHandshake(t *testing.T) { tpt, err := initTransport. WithSessionOptions(Prologue(initPrologue)) require.NoError(t, err) - _, err = tpt.SecureOutbound(context.TODO(), initConn, respTransport.localID, nil) + _, err = tpt.SecureOutbound(context.Background(), initConn, respTransport.localID) require.Error(t, err) }() tpt, err := respTransport.WithSessionOptions(Prologue(respPrologue)) require.NoError(t, err) - _, err = tpt.SecureInbound(context.TODO(), respConn, "", nil) + _, err = tpt.SecureInbound(context.Background(), respConn, "") require.Error(t, err) <-done } type earlyDataHandler struct { - send func(context.Context, net.Conn, peer.ID) []byte - received func(context.Context, net.Conn, []byte) error + send func(context.Context, net.Conn, peer.ID) *pb.NoiseExtensions + received func(context.Context, net.Conn, *pb.NoiseExtensions) error } -func (e *earlyDataHandler) Send(ctx context.Context, conn net.Conn, id peer.ID) []byte { +func (e *earlyDataHandler) Send(ctx context.Context, conn net.Conn, id peer.ID) *pb.NoiseExtensions { if e.send == nil { return nil } return e.send(ctx, conn, id) } -func (e *earlyDataHandler) Received(ctx context.Context, conn net.Conn, data []byte) error { +func (e *earlyDataHandler) Received(ctx context.Context, conn net.Conn, ext *pb.NoiseExtensions) error { if e.received == nil { return nil } - return e.received(ctx, conn, data) + return e.received(ctx, conn, ext) } func TestEarlyDataAccepted(t *testing.T) { - var receivedEarlyData []byte - serverEDH := &earlyDataHandler{ - received: func(_ context.Context, _ net.Conn, data []byte) error { - receivedEarlyData = data + handshake := func(t *testing.T, client, server EarlyDataHandler) { + t.Helper() + initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(client, nil)) + require.NoError(t, err) + tpt := newTestTransport(t, crypto.Ed25519, 2048) + respTransport, err := tpt.WithSessionOptions(EarlyData(nil, server)) + require.NoError(t, err) + + initConn, respConn := newConnPair(t) + + errChan := make(chan error) + go func() { + _, err := respTransport.SecureInbound(context.Background(), initConn, "") + errChan <- err + }() + + conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) + require.NoError(t, err) + select { + case <-time.After(500 * time.Millisecond): + t.Fatal("timeout") + case err := <-errChan: + require.NoError(t, err) + } + defer conn.Close() + } + + var receivedExtensions *pb.NoiseExtensions + receivingEDH := &earlyDataHandler{ + received: func(_ context.Context, _ net.Conn, ext *pb.NoiseExtensions) error { + receivedExtensions = ext return nil }, } - clientEDH := &earlyDataHandler{ - send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") }, + sendingEDH := &earlyDataHandler{ + send: func(context.Context, net.Conn, peer.ID) *pb.NoiseExtensions { + return &pb.NoiseExtensions{WebtransportCerthashes: [][]byte{[]byte("foobar")}} + }, } - initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH)) - require.NoError(t, err) - tpt := newTestTransport(t, crypto.Ed25519, 2048) - respTransport, err := tpt.WithSessionOptions(EarlyData(serverEDH)) - require.NoError(t, err) - - initConn, respConn := newConnPair(t) - errChan := make(chan error) - go func() { - _, err := respTransport.SecureInbound(context.Background(), initConn, "", nil) - errChan <- err - }() - - conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID, nil) - require.NoError(t, err) - defer conn.Close() + t.Run("client sending", func(t *testing.T) { + handshake(t, sendingEDH, receivingEDH) + require.Equal(t, [][]byte{[]byte("foobar")}, receivedExtensions.WebtransportCerthashes) + receivedExtensions = nil + }) - require.Equal(t, []byte("foobar"), receivedEarlyData) + t.Run("server sending", func(t *testing.T) { + handshake(t, receivingEDH, sendingEDH) + require.Equal(t, [][]byte{[]byte("foobar")}, receivedExtensions.WebtransportCerthashes) + receivedExtensions = nil + }) } func TestEarlyDataRejected(t *testing.T) { - serverEDH := &earlyDataHandler{ - received: func(_ context.Context, _ net.Conn, data []byte) error { return errors.New("nope") }, - } - clientEDH := &earlyDataHandler{ - send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") }, + handshake := func(t *testing.T, client, server EarlyDataHandler) (clientErr, serverErr error) { + initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(client, nil)) + require.NoError(t, err) + tpt := newTestTransport(t, crypto.Ed25519, 2048) + respTransport, err := tpt.WithSessionOptions(EarlyData(nil, server)) + require.NoError(t, err) + + initConn, respConn := newConnPair(t) + + errChan := make(chan error) + go func() { + _, err := respTransport.SecureInbound(context.Background(), initConn, "") + errChan <- err + }() + + // As early data is sent with the last handshake message, the handshake will appear + // to succeed for the client. + var conn sec.SecureConn + conn, clientErr = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) + if clientErr == nil { + _, clientErr = conn.Read([]byte{0}) + } + + select { + case <-time.After(500 * time.Millisecond): + t.Fatal("timeout") + case err := <-errChan: + serverErr = err + } + return + } + + receivingEDH := &earlyDataHandler{ + received: func(context.Context, net.Conn, *pb.NoiseExtensions) error { return errors.New("nope") }, + } + sendingEDH := &earlyDataHandler{ + send: func(context.Context, net.Conn, peer.ID) *pb.NoiseExtensions { + return &pb.NoiseExtensions{WebtransportCerthashes: [][]byte{[]byte("foobar")}} + }, } - initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH)) - require.NoError(t, err) - tpt := newTestTransport(t, crypto.Ed25519, 2048) - respTransport, err := tpt.WithSessionOptions(EarlyData(serverEDH)) - require.NoError(t, err) - initConn, respConn := newConnPair(t) + t.Run("client sending", func(t *testing.T) { + clientErr, serverErr := handshake(t, sendingEDH, receivingEDH) + require.Error(t, clientErr) + require.EqualError(t, serverErr, "nope") - errChan := make(chan error) - go func() { - _, err := respTransport.SecureInbound(context.Background(), initConn, "", nil) - errChan <- err - }() + }) - _, err = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID, nil) - require.Error(t, err) - - select { - case <-time.After(500 * time.Millisecond): - t.Fatal("timeout") - case err := <-errChan: - require.EqualError(t, err, "nope") - } + t.Run("server sending", func(t *testing.T) { + clientErr, serverErr := handshake(t, receivingEDH, sendingEDH) + require.Error(t, serverErr) + require.EqualError(t, clientErr, "nope") + }) } -func TestEarlyDataAcceptedWithNoHandler(t *testing.T) { +func TestEarlyfffDataAcceptedWithNoHandler(t *testing.T) { clientEDH := &earlyDataHandler{ - send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") }, + send: func(ctx context.Context, conn net.Conn, id peer.ID) *pb.NoiseExtensions { + return &pb.NoiseExtensions{WebtransportCerthashes: [][]byte{[]byte("foobar")}} + }, } - initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH)) + initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH, nil)) require.NoError(t, err) respTransport := newTestTransport(t, crypto.Ed25519, 2048) @@ -522,11 +571,11 @@ func TestEarlyDataAcceptedWithNoHandler(t *testing.T) { errChan := make(chan error) go func() { - _, err := respTransport.SecureInbound(context.Background(), initConn, "", nil) + _, err := respTransport.SecureInbound(context.Background(), initConn, "") errChan <- err }() - _, err = initTransport.SecureOutbound(context.Background(), respConn, respTransport.localID, nil) + conn, err := initTransport.SecureOutbound(context.Background(), respConn, respTransport.localID) require.NoError(t, err) defer conn.Close() diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index ca8fe1cf35..9ca2df5a2c 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -194,7 +194,7 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (* if err != nil { return nil, fmt.Errorf("failed to initialize Noise session: %w", err) } - c, err := n.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "") + c, err := n.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "", nil) if err != nil { return nil, err } diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 03cacb133c..9f7022e7d5 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -212,7 +212,7 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p if err != nil { return nil, fmt.Errorf("failed to create Noise transport: %w", err) } - c, err := n.SecureOutbound(ctx, &webtransportStream{Stream: str, wsess: sess}, p) + c, err := n.SecureOutbound(ctx, &webtransportStream{Stream: str, wsess: sess}, p, nil) if err != nil { return nil, err } From 196d970ce4f7279881a62eaee79529cfe91ea0ae Mon Sep 17 00:00:00 2001 From: julian88110 <111450570+julian88110@users.noreply.github.com> Date: Wed, 21 Sep 2022 11:51:30 -0700 Subject: [PATCH 05/13] Add stubs for noise --- p2p/security/noise/session_transport.go | 4 +- p2p/security/noise/transport.go | 14 +---- p2p/security/noise/transport_test.go | 71 +++++-------------------- 3 files changed, 17 insertions(+), 72 deletions(-) diff --git a/p2p/security/noise/session_transport.go b/p2p/security/noise/session_transport.go index 7b468edbb4..295d9779ed 100644 --- a/p2p/security/noise/session_transport.go +++ b/p2p/security/noise/session_transport.go @@ -64,7 +64,7 @@ type SessionTransport struct { // SecureInbound runs the Noise handshake as the responder. // If p is empty, connections from any peer are accepted. -func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { +func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.responderEarlyDataHandler, false) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) @@ -76,6 +76,6 @@ func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, } // SecureOutbound runs the Noise handshake as the initiator. -func (i *SessionTransport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { +func (i *SessionTransport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.responderEarlyDataHandler, true) } diff --git a/p2p/security/noise/transport.go b/p2p/security/noise/transport.go index 9c6369f586..0fd20f4b66 100644 --- a/p2p/security/noise/transport.go +++ b/p2p/security/noise/transport.go @@ -40,13 +40,8 @@ func New(privkey crypto.PrivKey) (*Transport, error) { // SecureInbound runs the Noise handshake as the responder. // If p is empty, connections from any peer are accepted. -<<<<<<< HEAD -func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - c, err := newSecureSession(t, ctx, insecure, p, nil, nil, nil, false) -======= func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { - c, err := newSecureSession(t, ctx, insecure, p, nil, nil, false) ->>>>>>> origin/muxer-selection-optimize + c, err := newSecureSession(t, ctx, insecure, p, nil, nil, nil, false) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) if maErr == nil { @@ -57,13 +52,8 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer } // SecureOutbound runs the Noise handshake as the initiator. -<<<<<<< HEAD -func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - return newSecureSession(t, ctx, insecure, p, nil, nil, nil, true) -======= func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { - return newSecureSession(t, ctx, insecure, p, nil, nil, true) ->>>>>>> origin/muxer-selection-optimize + return newSecureSession(t, ctx, insecure, p, nil, nil, nil, true) } func (t *Transport) WithSessionOptions(opts ...SessionOption) (sec.SecureTransport, error) { diff --git a/p2p/security/noise/transport_test.go b/p2p/security/noise/transport_test.go index 7faf672862..86d25da61b 100644 --- a/p2p/security/noise/transport_test.go +++ b/p2p/security/noise/transport_test.go @@ -199,7 +199,7 @@ func TestPeerIDMismatchOutboundFailsHandshake(t *testing.T) { errChan <- err }() - _, err := respTransport.SecureInbound(context.Background(), resp, "") + _, err := respTransport.SecureInbound(context.Background(), resp, "", nil) require.Error(t, err) initErr := <-errChan @@ -215,13 +215,13 @@ func TestPeerIDMismatchInboundFailsHandshake(t *testing.T) { done := make(chan struct{}) go func() { defer close(done) - conn, err := initTransport.SecureOutbound(context.Background(), init, respTransport.localID) + conn, err := initTransport.SecureOutbound(context.Background(), init, respTransport.localID, nil) assert.NoError(t, err) _, err = conn.Read([]byte{0}) assert.Error(t, err) }() - _, err := respTransport.SecureInbound(context.Background(), resp, "a-random-peer-id") + _, err := respTransport.SecureInbound(context.Background(), resp, "a-random-peer-id", nil) require.Error(t, err, "expected responder to fail with peer ID mismatch error") <-done } @@ -388,7 +388,7 @@ func TestPrologueMatches(t *testing.T) { tpt, err := initTransport. WithSessionOptions(Prologue(commonPrologue)) require.NoError(t, err) - conn, err := tpt.SecureOutbound(context.Background(), initConn, respTransport.localID) + conn, err := tpt.SecureOutbound(context.Background(), initConn, respTransport.localID, nil) require.NoError(t, err) defer conn.Close() }() @@ -396,7 +396,7 @@ func TestPrologueMatches(t *testing.T) { tpt, err := respTransport. WithSessionOptions(Prologue(commonPrologue)) require.NoError(t, err) - conn, err := tpt.SecureInbound(context.Background(), respConn, "") + conn, err := tpt.SecureInbound(context.Background(), respConn, "", nil) require.NoError(t, err) defer conn.Close() <-done @@ -416,14 +416,14 @@ func TestPrologueDoesNotMatchFailsHandshake(t *testing.T) { tpt, err := initTransport. WithSessionOptions(Prologue(initPrologue)) require.NoError(t, err) - _, err = tpt.SecureOutbound(context.Background(), initConn, respTransport.localID) + _, err = tpt.SecureOutbound(context.Background(), initConn, respTransport.localID, nil) require.Error(t, err) }() tpt, err := respTransport.WithSessionOptions(Prologue(respPrologue)) require.NoError(t, err) - _, err = tpt.SecureInbound(context.Background(), respConn, "") + _, err = tpt.SecureInbound(context.Background(), respConn, "", nil) require.Error(t, err) <-done } @@ -460,11 +460,11 @@ func TestEarlyDataAccepted(t *testing.T) { errChan := make(chan error) go func() { - _, err := respTransport.SecureInbound(context.Background(), initConn, "") + _, err := respTransport.SecureInbound(context.Background(), initConn, "", nil) errChan <- err }() - conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) + conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID, nil) require.NoError(t, err) select { case <-time.After(500 * time.Millisecond): @@ -508,63 +508,19 @@ func TestEarlyDataRejected(t *testing.T) { tpt := newTestTransport(t, crypto.Ed25519, 2048) respTransport, err := tpt.WithSessionOptions(EarlyData(nil, server)) require.NoError(t, err) -<<<<<<< HEAD - - initConn, respConn := newConnPair(t) - - errChan := make(chan error) - go func() { - _, err := respTransport.SecureInbound(context.Background(), initConn, "") - errChan <- err - }() - - // As early data is sent with the last handshake message, the handshake will appear - // to succeed for the client. - var conn sec.SecureConn - conn, clientErr = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) - if clientErr == nil { - _, clientErr = conn.Read([]byte{0}) - } - - select { - case <-time.After(500 * time.Millisecond): - t.Fatal("timeout") - case err := <-errChan: - serverErr = err - } - return - } - - receivingEDH := &earlyDataHandler{ - received: func(context.Context, net.Conn, *pb.NoiseExtensions) error { return errors.New("nope") }, - } - sendingEDH := &earlyDataHandler{ - send: func(context.Context, net.Conn, peer.ID) *pb.NoiseExtensions { - return &pb.NoiseExtensions{WebtransportCerthashes: [][]byte{[]byte("foobar")}} - }, - } - - t.Run("client sending", func(t *testing.T) { - clientErr, serverErr := handshake(t, sendingEDH, receivingEDH) - require.Error(t, clientErr) - require.EqualError(t, serverErr, "nope") - - }) - -======= initConn, respConn := newConnPair(t) errChan := make(chan error) go func() { - _, err := respTransport.SecureInbound(context.Background(), initConn, "") + _, err := respTransport.SecureInbound(context.Background(), initConn, "", nil) errChan <- err }() // As early data is sent with the last handshake message, the handshake will appear // to succeed for the client. var conn sec.SecureConn - conn, clientErr = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) + conn, clientErr = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID, nil) if clientErr == nil { _, clientErr = conn.Read([]byte{0}) } @@ -594,7 +550,6 @@ func TestEarlyDataRejected(t *testing.T) { }) ->>>>>>> origin/muxer-selection-optimize t.Run("server sending", func(t *testing.T) { clientErr, serverErr := handshake(t, receivingEDH, sendingEDH) require.Error(t, serverErr) @@ -616,11 +571,11 @@ func TestEarlyfffDataAcceptedWithNoHandler(t *testing.T) { errChan := make(chan error) go func() { - _, err := respTransport.SecureInbound(context.Background(), initConn, "") + _, err := respTransport.SecureInbound(context.Background(), initConn, "", nil) errChan <- err }() - conn, err := initTransport.SecureOutbound(context.Background(), respConn, respTransport.localID) + conn, err := initTransport.SecureOutbound(context.Background(), respConn, respTransport.localID, nil) require.NoError(t, err) defer conn.Close() From ce8976cd8792fefb740df842ae28227d04c0639c Mon Sep 17 00:00:00 2001 From: julian88110 <111450570+julian88110@users.noreply.github.com> Date: Wed, 21 Sep 2022 12:21:07 -0700 Subject: [PATCH 06/13] clean up code --- p2p/muxer/muxer-multistream/multistream.go | 3 --- p2p/net/upgrader/upgrader.go | 8 -------- p2p/security/noise/benchmark_test.go | 2 -- p2p/security/tls/cmd/tlsdiag/client.go | 2 -- p2p/security/tls/cmd/tlsdiag/server.go | 2 -- p2p/security/tls/transport.go | 3 --- 6 files changed, 20 deletions(-) diff --git a/p2p/muxer/muxer-multistream/multistream.go b/p2p/muxer/muxer-multistream/multistream.go index 208775842f..29a1eadf25 100644 --- a/p2p/muxer/muxer-multistream/multistream.go +++ b/p2p/muxer/muxer-multistream/multistream.go @@ -52,14 +52,12 @@ func (t *Transport) NewConn(nc net.Conn, isServer bool, scope network.PeerScope) return nil, err } proto = selected - fmt.Println(">>>>>> Server selected muxer: ", proto) } else { selected, err := mss.SelectOneOf(t.OrderPreference, nc) if err != nil { return nil, err } proto = selected - fmt.Println(">>>>>> Client selected muxer: ", proto) } if t.NegotiateTimeout != 0 { @@ -68,7 +66,6 @@ func (t *Transport) NewConn(nc net.Conn, isServer bool, scope network.PeerScope) } } - fmt.Println(">>>>>> multistream muxer conn selected proto: %s", proto) tpt, ok := t.tpts[proto] if !ok { return nil, fmt.Errorf("selected protocol we don't have a transport for") diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go index 2c7a1d44c6..f16bf95555 100644 --- a/p2p/net/upgrader/upgrader.go +++ b/p2p/net/upgrader/upgrader.go @@ -80,8 +80,6 @@ var _ transport.Upgrader = &upgrader{} func New(secureMuxer sec.SecureMuxer, muxer network.Multiplexer, opts ...Option) (transport.Upgrader, error) { - fmt.Printf(">>>>>> New upgrader with muxer type: %T\n", muxer) - u := &upgrader{ secure: secureMuxer, muxer: muxer, @@ -179,7 +177,6 @@ func (u *upgrader) upgrade(ctx context.Context, t transport.Transport, maconn ma sconn.Close() return nil, fmt.Errorf("failed to negotiate stream multiplexer: %s", err) } - fmt.Printf(">>>>>> upgrader got muxed connection from setupMuxer: %T\n", smconn) tc := &transportConn{ MuxedConn: smconn, @@ -202,9 +199,6 @@ func (u *upgrader) setupSecurity(ctx context.Context, conn net.Conn, p peer.ID, muxers = msmuxer.GetTranspotKeys() } - // DEBUG - fmt.Println(">>>>>> Upgrader appending muxers to security proto: ", muxers) - if dir == network.DirInbound { return u.secure.SecureInbound(ctx, conn, p, muxers) } @@ -219,8 +213,6 @@ func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server b muxerSelected := conn.ConnState().EarlyData // Use muxer selected from security handshake if available. Otherwise fall back to multistream-selection. - fmt.Println(">>>>>> upgrader: muxer key from early data is: ", muxerSelected) - if ok && len(muxerSelected) > 0 { //if false && ok { tpt, ok := msmuxer.GetTranspotByKey(muxerSelected) diff --git a/p2p/security/noise/benchmark_test.go b/p2p/security/noise/benchmark_test.go index 0778b8e623..cfd19829ca 100644 --- a/p2p/security/noise/benchmark_test.go +++ b/p2p/security/noise/benchmark_test.go @@ -98,8 +98,6 @@ func (b benchenv) connect(stopTimer bool) (*secureSession, *secureSession) { return initSession.(*secureSession), respSession.(*secureSession) } -// >>>>>> TODO <<<<<< add test cases for non-null muxers. - func drain(r io.Reader, done chan<- error, writeTo io.Writer) { _, err := io.Copy(writeTo, r) done <- err diff --git a/p2p/security/tls/cmd/tlsdiag/client.go b/p2p/security/tls/cmd/tlsdiag/client.go index db7ef31175..96025e2b9b 100644 --- a/p2p/security/tls/cmd/tlsdiag/client.go +++ b/p2p/security/tls/cmd/tlsdiag/client.go @@ -61,5 +61,3 @@ func StartClient() error { fmt.Printf("Received message from server: %s\n", string(data)) return nil } - -// >>>>>> TODO <<<<<< Add an early data test case here. diff --git a/p2p/security/tls/cmd/tlsdiag/server.go b/p2p/security/tls/cmd/tlsdiag/server.go index 6b575e4f88..2b91de5617 100644 --- a/p2p/security/tls/cmd/tlsdiag/server.go +++ b/p2p/security/tls/cmd/tlsdiag/server.go @@ -66,5 +66,3 @@ func handleConn(tp *libp2ptls.Transport, conn net.Conn) error { fmt.Printf("Closing connection to %s\n", conn.RemoteAddr()) return sconn.Close() } - -// >>>>>> TODO <<<<<< Add early data diagcase. diff --git a/p2p/security/tls/transport.go b/p2p/security/tls/transport.go index 5f59dba068..b842e77b7d 100644 --- a/p2p/security/tls/transport.go +++ b/p2p/security/tls/transport.go @@ -98,9 +98,6 @@ func (t *Transport) handshake(ctx context.Context, tlsConn *tls.Conn, keyCh <-ch if err := tlsConn.HandshakeContext(ctx); err != nil { return nil, err } - /// - ac := tlsConn.ConnectionState().NegotiatedProtocol - fmt.Println(">>>>>> TLS negotiated app protocol is: ", ac) // Should be ready by this point, don't block. var remotePubKey ci.PubKey From 57158c6df4e6eaae5ab0286dc1fbcc5a81f0958b Mon Sep 17 00:00:00 2001 From: julian88110 <111450570+julian88110@users.noreply.github.com> Date: Mon, 26 Sep 2022 08:39:40 -0700 Subject: [PATCH 07/13] Switch over to passing muxers to security transport constructors --- config/config.go | 2 +- config/constructor_types.go | 26 +++--- config/muxer.go | 2 +- config/reflection_magic.go | 8 +- config/security.go | 17 ++-- config/transport.go | 2 +- core/network/conn.go | 2 +- core/sec/insecure/insecure.go | 4 +- core/sec/insecure/insecure_test.go | 4 +- core/sec/security.go | 8 +- p2p/muxer/muxer-multistream/multistream.go | 4 - p2p/net/conn-security-multistream/ssms.go | 10 +-- .../conn-security-multistream/ssms_test.go | 31 +++----- p2p/net/upgrader/listener_test.go | 8 +- p2p/net/upgrader/upgrader.go | 18 +---- p2p/security/noise/benchmark_test.go | 4 +- p2p/security/noise/session_test.go | 2 +- p2p/security/noise/session_transport.go | 4 +- p2p/security/noise/transport.go | 4 +- p2p/security/noise/transport_test.go | 38 ++++----- p2p/security/tls/cmd/tlsdiag/client.go | 4 +- p2p/security/tls/cmd/tlsdiag/server.go | 4 +- p2p/security/tls/transport.go | 12 +-- p2p/security/tls/transport_test.go | 79 ++++++++++--------- p2p/transport/webtransport/listener.go | 2 +- p2p/transport/webtransport/transport.go | 2 +- 26 files changed, 141 insertions(+), 160 deletions(-) diff --git a/config/config.go b/config/config.go index 0634ef4a36..adbcf1d122 100644 --- a/config/config.go +++ b/config/config.go @@ -173,7 +173,7 @@ func (cfg *Config) addTransports(h host.Host) error { secure = makeInsecureTransport(h.ID(), cfg.PeerKey) } else { var err error - secure, err = makeSecurityMuxer(h, cfg.SecurityTransports) + secure, err = makeSecurityMuxer(h, cfg.SecurityTransports, cfg.Muxers) if err != nil { return err } diff --git a/config/constructor_types.go b/config/constructor_types.go index 53a105ba39..2f43a2f0da 100644 --- a/config/constructor_types.go +++ b/config/constructor_types.go @@ -35,42 +35,46 @@ var ( peerIDType = reflect.TypeOf((peer.ID)("")) pskType = reflect.TypeOf((pnet.PSK)(nil)) resolverType = reflect.TypeOf((*madns.Resolver)(nil)) + muxersType = reflect.TypeOf(([]string)(nil)) ) var argTypes = map[reflect.Type]constructor{ - upgraderType: func(_ host.Host, u transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + upgraderType: func(_ host.Host, u transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { return u }, - hostType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + hostType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { return h }, - networkType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + networkType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { return h.Network() }, - pskType: func(_ host.Host, _ transport.Upgrader, psk pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + pskType: func(_ host.Host, _ transport.Upgrader, psk pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { return psk }, - connGaterType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, cg connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + connGaterType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, cg connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { return cg }, - peerIDType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + peerIDType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { return h.ID() }, - privKeyType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + privKeyType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { return h.Peerstore().PrivKey(h.ID()) }, - pubKeyType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + pubKeyType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { return h.Peerstore().PubKey(h.ID()) }, - pstoreType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + pstoreType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { return h.Peerstore() }, - rcmgrType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, rcmgr network.ResourceManager, _ *madns.Resolver) interface{} { + rcmgrType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, rcmgr network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { return rcmgr }, - resolverType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, r *madns.Resolver) interface{} { + resolverType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, r *madns.Resolver, _ []string) interface{} { return r }, + muxersType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, muxers []string) interface{} { + return muxers + }, } func newArgTypeSet(types ...reflect.Type) map[reflect.Type]constructor { diff --git a/config/muxer.go b/config/muxer.go index 30a3fa2f7e..e7b64c1345 100644 --- a/config/muxer.go +++ b/config/muxer.go @@ -35,7 +35,7 @@ func MuxerConstructor(m interface{}) (MuxC, error) { return nil, err } return func(h host.Host) (network.Multiplexer, error) { - t, err := ctor(h, nil, nil, nil, nil, nil) + t, err := ctor(h, nil, nil, nil, nil, nil, nil) if err != nil { return nil, err } diff --git a/config/reflection_magic.go b/config/reflection_magic.go index bb2f52c6b0..839384b1e8 100644 --- a/config/reflection_magic.go +++ b/config/reflection_magic.go @@ -82,7 +82,7 @@ func callConstructor(c reflect.Value, args []reflect.Value) (interface{}, error) return val, err } -type constructor func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater, network.ResourceManager, *madns.Resolver) interface{} +type constructor func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater, network.ResourceManager, *madns.Resolver, []string) interface{} func makeArgumentConstructors(fnType reflect.Type, argTypes map[reflect.Type]constructor) ([]constructor, error) { params := fnType.NumIn() @@ -133,7 +133,7 @@ func makeConstructor( tptType reflect.Type, argTypes map[reflect.Type]constructor, opts ...interface{}, -) (func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater, network.ResourceManager, *madns.Resolver) (interface{}, error), error) { +) (func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater, network.ResourceManager, *madns.Resolver, []string) (interface{}, error), error) { v := reflect.ValueOf(tpt) // avoid panicing on nil/zero value. if v == (reflect.Value{}) { @@ -157,10 +157,10 @@ func makeConstructor( return nil, err } - return func(h host.Host, u transport.Upgrader, psk pnet.PSK, cg connmgr.ConnectionGater, rcmgr network.ResourceManager, resolver *madns.Resolver) (interface{}, error) { + return func(h host.Host, u transport.Upgrader, psk pnet.PSK, cg connmgr.ConnectionGater, rcmgr network.ResourceManager, resolver *madns.Resolver, muxers []string) (interface{}, error) { arguments := make([]reflect.Value, 0, len(argConstructors)+len(opts)) for i, makeArg := range argConstructors { - if arg := makeArg(h, u, psk, cg, rcmgr, resolver); arg != nil { + if arg := makeArg(h, u, psk, cg, rcmgr, resolver, muxers); arg != nil { arguments = append(arguments, reflect.ValueOf(arg)) } else { // ValueOf an un-typed nil yields a zero reflect diff --git a/config/security.go b/config/security.go index 330cfaf334..c425fdb73c 100644 --- a/config/security.go +++ b/config/security.go @@ -12,7 +12,7 @@ import ( ) // SecC is a security transport constructor. -type SecC func(h host.Host) (sec.SecureTransport, error) +type SecC func(h host.Host, muxers []string) (sec.SecureTransport, error) // MsSecC is a tuple containing a security transport constructor and a protocol // ID. @@ -24,6 +24,7 @@ type MsSecC struct { var securityArgTypes = newArgTypeSet( hostType, networkType, peerIDType, privKeyType, pubKeyType, pstoreType, + muxersType, ) // SecurityConstructor creates a security constructor from the passed parameter @@ -31,7 +32,7 @@ var securityArgTypes = newArgTypeSet( func SecurityConstructor(security interface{}) (SecC, error) { // Already constructed? if t, ok := security.(sec.SecureTransport); ok { - return func(_ host.Host) (sec.SecureTransport, error) { + return func(_ host.Host, _ []string) (sec.SecureTransport, error) { return t, nil }, nil } @@ -40,8 +41,8 @@ func SecurityConstructor(security interface{}) (SecC, error) { if err != nil { return nil, err } - return func(h host.Host) (sec.SecureTransport, error) { - t, err := ctor(h, nil, nil, nil, nil, nil) + return func(h host.Host, muxers []string) (sec.SecureTransport, error) { + t, err := ctor(h, nil, nil, nil, nil, nil, muxers) if err != nil { return nil, err } @@ -55,7 +56,7 @@ func makeInsecureTransport(id peer.ID, privKey crypto.PrivKey) sec.SecureMuxer { return secMuxer } -func makeSecurityMuxer(h host.Host, tpts []MsSecC) (sec.SecureMuxer, error) { +func makeSecurityMuxer(h host.Host, tpts []MsSecC, muxers []MsMuxC) (sec.SecureMuxer, error) { secMuxer := new(csms.SSMuxer) transportSet := make(map[string]struct{}, len(tpts)) for _, tptC := range tpts { @@ -64,8 +65,12 @@ func makeSecurityMuxer(h host.Host, tpts []MsSecC) (sec.SecureMuxer, error) { } transportSet[tptC.ID] = struct{}{} } + muxIds := make([]string, 0, len(muxers)) + for _, muxc := range muxers { + muxIds = append(muxIds, muxc.ID) + } for _, tptC := range tpts { - tpt, err := tptC.SecC(h) + tpt, err := tptC.SecC(h, muxIds) if err != nil { return nil, err } diff --git a/config/transport.go b/config/transport.go index 850357f5a4..6105e77f13 100644 --- a/config/transport.go +++ b/config/transport.go @@ -50,7 +50,7 @@ func TransportConstructor(tpt interface{}, opts ...interface{}) (TptC, error) { return nil, err } return func(h host.Host, u transport.Upgrader, psk pnet.PSK, cg connmgr.ConnectionGater, rcmgr network.ResourceManager, resolver *madns.Resolver) (transport.Transport, error) { - t, err := ctor(h, u, psk, cg, rcmgr, resolver) + t, err := ctor(h, u, psk, cg, rcmgr, resolver, nil) if err != nil { return nil, err } diff --git a/core/network/conn.go b/core/network/conn.go index 66e5ce7538..8e4eb0aefa 100644 --- a/core/network/conn.go +++ b/core/network/conn.go @@ -56,7 +56,7 @@ type ConnSecurity interface { // RemotePublicKey returns the public key of the remote peer. RemotePublicKey() ic.PubKey - // Early data negotiated by the security protocol. Empty if not supported. + // Connection state info of the secured connection. ConnState() ConnectionState } diff --git a/core/sec/insecure/insecure.go b/core/sec/insecure/insecure.go index 4888258439..2d94f43804 100644 --- a/core/sec/insecure/insecure.go +++ b/core/sec/insecure/insecure.go @@ -61,7 +61,7 @@ func (t *Transport) LocalPrivateKey() ci.PrivKey { // // SecureInbound may fail if the remote peer sends an ID and public key that are inconsistent // with each other, or if a network error occurs during the ID exchange. -func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { +func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { conn := &Conn{ Conn: insecure, local: t.id, @@ -88,7 +88,7 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer // SecureOutbound may fail if the remote peer sends an ID and public key that are inconsistent // with each other, or if the ID sent by the remote peer does not match the one dialed. It may // also fail if a network error occurs during the ID exchange. -func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { +func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { conn := &Conn{ Conn: insecure, local: t.id, diff --git a/core/sec/insecure/insecure_test.go b/core/sec/insecure/insecure_test.go index da16772be0..a3ce8314f4 100644 --- a/core/sec/insecure/insecure_test.go +++ b/core/sec/insecure/insecure_test.go @@ -94,9 +94,9 @@ func connect(t *testing.T, clientTpt, serverTpt *Transport, clientExpectsID, ser done := make(chan struct{}) go func() { defer close(done) - clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, clientExpectsID, nil) + clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, clientExpectsID) }() - serverConn, serverErr = serverTpt.SecureInbound(context.TODO(), server, serverExpectsID, nil) + serverConn, serverErr = serverTpt.SecureInbound(context.TODO(), server, serverExpectsID) <-done return } diff --git a/core/sec/security.go b/core/sec/security.go index b76e62b900..c192a56a91 100644 --- a/core/sec/security.go +++ b/core/sec/security.go @@ -20,10 +20,10 @@ type SecureConn interface { type SecureTransport interface { // SecureInbound secures an inbound connection. // If p is empty, connections from any peer are accepted. - SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (SecureConn, error) + SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error) // SecureOutbound secures an outbound connection. - SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (SecureConn, error) + SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error) } // A SecureMuxer is a wrapper around SecureTransport which can select security protocols @@ -33,10 +33,10 @@ type SecureMuxer interface { // The returned boolean indicates whether the connection should be treated as a server // connection; in the case of SecureInbound it should always be true. // If p is empty, connections from any peer are accepted. - SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (SecureConn, bool, error) + SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, bool, error) // SecureOutbound secures an outbound connection. // The returned boolean indicates whether the connection should be treated as a server // connection due to simultaneous open. - SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (SecureConn, bool, error) + SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, bool, error) } diff --git a/p2p/muxer/muxer-multistream/multistream.go b/p2p/muxer/muxer-multistream/multistream.go index 29a1eadf25..9f78a26b3c 100644 --- a/p2p/muxer/muxer-multistream/multistream.go +++ b/p2p/muxer/muxer-multistream/multistream.go @@ -74,10 +74,6 @@ func (t *Transport) NewConn(nc net.Conn, isServer bool, scope network.PeerScope) return tpt.NewConn(nc, isServer, scope) } -func (t *Transport) GetTranspotKeys() []string { - return t.OrderPreference -} - func (t *Transport) GetTranspotByKey(key string) (network.Multiplexer, bool) { val, ok := t.tpts[key] return val, ok diff --git a/p2p/net/conn-security-multistream/ssms.go b/p2p/net/conn-security-multistream/ssms.go index a5e3f07968..595d8dfde6 100644 --- a/p2p/net/conn-security-multistream/ssms.go +++ b/p2p/net/conn-security-multistream/ssms.go @@ -40,18 +40,18 @@ func (sm *SSMuxer) AddTransport(path string, transport sec.SecureTransport) { // SecureInbound secures an inbound connection using this multistream // multiplexed stream security transport. -func (sm *SSMuxer) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, bool, error) { +func (sm *SSMuxer) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, bool, error) { tpt, _, err := sm.selectProto(ctx, insecure, true) if err != nil { return nil, false, err } - sconn, err := tpt.SecureInbound(ctx, insecure, p, muxers) + sconn, err := tpt.SecureInbound(ctx, insecure, p) return sconn, true, err } // SecureOutbound secures an outbound connection using this multistream // multiplexed stream security transport. -func (sm *SSMuxer) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, bool, error) { +func (sm *SSMuxer) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, bool, error) { tpt, server, err := sm.selectProto(ctx, insecure, false) if err != nil { return nil, false, err @@ -59,7 +59,7 @@ func (sm *SSMuxer) SecureOutbound(ctx context.Context, insecure net.Conn, p peer var sconn sec.SecureConn if server { - sconn, err = tpt.SecureInbound(ctx, insecure, p, muxers) + sconn, err = tpt.SecureInbound(ctx, insecure, p) if err != nil { return nil, false, fmt.Errorf("failed to secure inbound connection: %s", err) } @@ -70,7 +70,7 @@ func (sm *SSMuxer) SecureOutbound(ctx context.Context, insecure net.Conn, p peer return nil, false, fmt.Errorf("unexpected peer") } } else { - sconn, err = tpt.SecureOutbound(ctx, insecure, p, muxers) + sconn, err = tpt.SecureOutbound(ctx, insecure, p) } return sconn, server, err diff --git a/p2p/net/conn-security-multistream/ssms_test.go b/p2p/net/conn-security-multistream/ssms_test.go index 67ea53939d..5aa5db352d 100644 --- a/p2p/net/conn-security-multistream/ssms_test.go +++ b/p2p/net/conn-security-multistream/ssms_test.go @@ -28,23 +28,17 @@ type TransportAdapter struct { mux *SSMuxer } -func (sm *TransportAdapter) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { - sconn, _, err := sm.mux.SecureInbound(ctx, insecure, p, muxers) +func (sm *TransportAdapter) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { + sconn, _, err := sm.mux.SecureInbound(ctx, insecure, p) return sconn, err } -func (sm *TransportAdapter) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { - sconn, _, err := sm.mux.SecureOutbound(ctx, insecure, p, muxers) +func (sm *TransportAdapter) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { + sconn, _, err := sm.mux.SecureOutbound(ctx, insecure, p) return sconn, err } -var clientMuxerList = [][]string{{}, {"muxer1/1.0.0", "muxer2/1.0.1"}, {"muxer1"}, {}} -var serverMuxerList = [][]string{{}, {"muxer2/1.0.1", "muxer1/1.0.0"}, {}, {"muxer1"}} -var insecureExpectedMuxers = []string{"", "", "", ""} - -const numMuxers = 4 - -func commonProto(t *testing.T, serverMuxers, clientMuxers []string, expectedMuxer string) { +func TestCommonProto(t *testing.T) { privA, idA := newPeer(t) privB, idB := newPeer(t) @@ -63,7 +57,7 @@ func commonProto(t *testing.T, serverMuxers, clientMuxers []string, expectedMuxe go func() { conn, err := ln.Accept() require.NoError(t, err) - c, err := muxB.SecureInbound(context.Background(), conn, idA, serverMuxers) + c, err := muxB.SecureInbound(context.Background(), conn, idA) require.NoError(t, err) connChan <- c }() @@ -73,7 +67,7 @@ func commonProto(t *testing.T, serverMuxers, clientMuxers []string, expectedMuxe cconn, err := net.Dial("tcp", ln.Addr().String()) require.NoError(t, err) - cc, err := muxA.SecureOutbound(context.Background(), cconn, idB, clientMuxers) + cc, err := muxA.SecureOutbound(context.Background(), cconn, idB) require.NoError(t, err) require.Equal(t, cc.LocalPeer(), idA) require.Equal(t, cc.RemotePeer(), idB) @@ -87,13 +81,6 @@ func commonProto(t *testing.T, serverMuxers, clientMuxers []string, expectedMuxe b, err := io.ReadAll(sc) require.NoError(t, err) require.Equal(t, "foobar", string(b)) - require.Equal(t, expectedMuxer, cc.ConnState().EarlyData) -} - -func TestCommonProto(t *testing.T) { - for i := 0; i < numMuxers; i++ { - commonProto(t, serverMuxerList[i], clientMuxerList[i], insecureExpectedMuxers[i]) - } } func TestNoCommonProto(t *testing.T) { @@ -116,7 +103,7 @@ func TestNoCommonProto(t *testing.T) { go func() { defer wg.Done() defer a.Close() - _, _, err := at.SecureInbound(ctx, a, "", nil) + _, _, err := at.SecureInbound(ctx, a, "") if err == nil { t.Error("connection should have failed") } @@ -125,7 +112,7 @@ func TestNoCommonProto(t *testing.T) { go func() { defer wg.Done() defer b.Close() - _, _, err := bt.SecureOutbound(ctx, b, "peerA", nil) + _, _, err := bt.SecureOutbound(ctx, b, "peerA") if err == nil { t.Error("connection should have failed") } diff --git a/p2p/net/upgrader/listener_test.go b/p2p/net/upgrader/listener_test.go index 523f712ed8..82c3952ef8 100644 --- a/p2p/net/upgrader/listener_test.go +++ b/p2p/net/upgrader/listener_test.go @@ -30,13 +30,13 @@ type MuxAdapter struct { var _ sec.SecureMuxer = &MuxAdapter{} -func (mux *MuxAdapter) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, bool, error) { - sconn, err := mux.tpt.SecureInbound(ctx, insecure, p, muxers) +func (mux *MuxAdapter) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, bool, error) { + sconn, err := mux.tpt.SecureInbound(ctx, insecure, p) return sconn, true, err } -func (mux *MuxAdapter) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, bool, error) { - sconn, err := mux.tpt.SecureOutbound(ctx, insecure, p, muxers) +func (mux *MuxAdapter) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, bool, error) { + sconn, err := mux.tpt.SecureOutbound(ctx, insecure, p) return sconn, false, err } diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go index f16bf95555..4785b36ad7 100644 --- a/p2p/net/upgrader/upgrader.go +++ b/p2p/net/upgrader/upgrader.go @@ -190,31 +190,18 @@ func (u *upgrader) upgrade(ctx context.Context, t transport.Transport, maconn ma } func (u *upgrader) setupSecurity(ctx context.Context, conn net.Conn, p peer.ID, dir network.Direction) (sec.SecureConn, bool, error) { - // Add candidate muxers to the security layer handshake process to save - // muxer negotiation round trip if possible. - // TODO: explore if there is a way of extracting muxers other than type assertion. - muxers := []string{} - msmuxer, ok := u.muxer.(*msmux.Transport) - if ok { - muxers = msmuxer.GetTranspotKeys() - } - if dir == network.DirInbound { - return u.secure.SecureInbound(ctx, conn, p, muxers) + return u.secure.SecureInbound(ctx, conn, p) } - return u.secure.SecureOutbound(ctx, conn, p, muxers) + return u.secure.SecureOutbound(ctx, conn, p) } func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server bool, scope network.PeerScope) (network.MuxedConn, error) { // TODO: The muxer should take a context. - - //// msmuxer, ok := u.muxer.(*msmux.Transport) muxerSelected := conn.ConnState().EarlyData - // Use muxer selected from security handshake if available. Otherwise fall back to multistream-selection. if ok && len(muxerSelected) > 0 { - //if false && ok { tpt, ok := msmuxer.GetTranspotByKey(muxerSelected) if !ok { return nil, fmt.Errorf("selected a muxer we don't have a transport for") @@ -223,7 +210,6 @@ func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server b return tpt.NewConn(conn, server, scope) } - //// done := make(chan struct{}) var smconn network.MuxedConn diff --git a/p2p/security/noise/benchmark_test.go b/p2p/security/noise/benchmark_test.go index cfd19829ca..52454f5959 100644 --- a/p2p/security/noise/benchmark_test.go +++ b/p2p/security/noise/benchmark_test.go @@ -81,10 +81,10 @@ func (b benchenv) connect(stopTimer bool) (*secureSession, *secureSession) { done := make(chan struct{}) go func() { defer close(done) - initSession, initErr = b.initTpt.SecureOutbound(context.TODO(), initConn, b.respTpt.localID, nil) + initSession, initErr = b.initTpt.SecureOutbound(context.TODO(), initConn, b.respTpt.localID) }() - respSession, respErr := b.respTpt.SecureInbound(context.TODO(), respConn, "", nil) + respSession, respErr := b.respTpt.SecureInbound(context.TODO(), respConn, "") <-done if initErr != nil { diff --git a/p2p/security/noise/session_test.go b/p2p/security/noise/session_test.go index f4b3cf31fc..85de01b2ba 100644 --- a/p2p/security/noise/session_test.go +++ b/p2p/security/noise/session_test.go @@ -20,7 +20,7 @@ func TestContextCancellationRespected(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err := initTransport.SecureOutbound(ctx, init, respTransport.localID, nil) + _, err := initTransport.SecureOutbound(ctx, init, respTransport.localID) require.Error(t, err) require.Equal(t, ctx.Err(), err) } diff --git a/p2p/security/noise/session_transport.go b/p2p/security/noise/session_transport.go index 295d9779ed..7b468edbb4 100644 --- a/p2p/security/noise/session_transport.go +++ b/p2p/security/noise/session_transport.go @@ -64,7 +64,7 @@ type SessionTransport struct { // SecureInbound runs the Noise handshake as the responder. // If p is empty, connections from any peer are accepted. -func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { +func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.responderEarlyDataHandler, false) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) @@ -76,6 +76,6 @@ func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, } // SecureOutbound runs the Noise handshake as the initiator. -func (i *SessionTransport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { +func (i *SessionTransport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.responderEarlyDataHandler, true) } diff --git a/p2p/security/noise/transport.go b/p2p/security/noise/transport.go index 0fd20f4b66..c6923698cc 100644 --- a/p2p/security/noise/transport.go +++ b/p2p/security/noise/transport.go @@ -40,7 +40,7 @@ func New(privkey crypto.PrivKey) (*Transport, error) { // SecureInbound runs the Noise handshake as the responder. // If p is empty, connections from any peer are accepted. -func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { +func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { c, err := newSecureSession(t, ctx, insecure, p, nil, nil, nil, false) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) @@ -52,7 +52,7 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer } // SecureOutbound runs the Noise handshake as the initiator. -func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { +func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { return newSecureSession(t, ctx, insecure, p, nil, nil, nil, true) } diff --git a/p2p/security/noise/transport_test.go b/p2p/security/noise/transport_test.go index 86d25da61b..2fa90d06ef 100644 --- a/p2p/security/noise/transport_test.go +++ b/p2p/security/noise/transport_test.go @@ -79,10 +79,10 @@ func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSess done := make(chan struct{}) go func() { defer close(done) - initConn, initErr = initTransport.SecureOutbound(context.Background(), init, respTransport.localID, nil) + initConn, initErr = initTransport.SecureOutbound(context.Background(), init, respTransport.localID) }() - respConn, respErr := respTransport.SecureInbound(context.Background(), resp, "", nil) + respConn, respErr := respTransport.SecureInbound(context.Background(), resp, "") <-done if initErr != nil { @@ -107,7 +107,7 @@ func TestDeadlines(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - _, err := initTransport.SecureOutbound(ctx, init, respTransport.localID, nil) + _, err := initTransport.SecureOutbound(ctx, init, respTransport.localID) if err == nil { t.Fatalf("expected i/o timeout err; got: %s", err) } @@ -172,7 +172,7 @@ func TestPeerIDMatch(t *testing.T) { done := make(chan struct{}) go func() { defer close(done) - conn, err := initTransport.SecureOutbound(context.Background(), init, respTransport.localID, nil) + conn, err := initTransport.SecureOutbound(context.Background(), init, respTransport.localID) assert.NoError(t, err) assert.Equal(t, conn.RemotePeer(), respTransport.localID) b := make([]byte, 6) @@ -181,7 +181,7 @@ func TestPeerIDMatch(t *testing.T) { assert.Equal(t, b, []byte("foobar")) }() - conn, err := respTransport.SecureInbound(context.Background(), resp, initTransport.localID, nil) + conn, err := respTransport.SecureInbound(context.Background(), resp, initTransport.localID) require.NoError(t, err) require.Equal(t, conn.RemotePeer(), initTransport.localID) _, err = conn.Write([]byte("foobar")) @@ -195,11 +195,11 @@ func TestPeerIDMismatchOutboundFailsHandshake(t *testing.T) { errChan := make(chan error) go func() { - _, err := initTransport.SecureOutbound(context.Background(), init, "a-random-peer-id", nil) + _, err := initTransport.SecureOutbound(context.Background(), init, "a-random-peer-id") errChan <- err }() - _, err := respTransport.SecureInbound(context.Background(), resp, "", nil) + _, err := respTransport.SecureInbound(context.Background(), resp, "") require.Error(t, err) initErr := <-errChan @@ -215,13 +215,13 @@ func TestPeerIDMismatchInboundFailsHandshake(t *testing.T) { done := make(chan struct{}) go func() { defer close(done) - conn, err := initTransport.SecureOutbound(context.Background(), init, respTransport.localID, nil) + conn, err := initTransport.SecureOutbound(context.Background(), init, respTransport.localID) assert.NoError(t, err) _, err = conn.Read([]byte{0}) assert.Error(t, err) }() - _, err := respTransport.SecureInbound(context.Background(), resp, "a-random-peer-id", nil) + _, err := respTransport.SecureInbound(context.Background(), resp, "a-random-peer-id") require.Error(t, err, "expected responder to fail with peer ID mismatch error") <-done } @@ -388,7 +388,7 @@ func TestPrologueMatches(t *testing.T) { tpt, err := initTransport. WithSessionOptions(Prologue(commonPrologue)) require.NoError(t, err) - conn, err := tpt.SecureOutbound(context.Background(), initConn, respTransport.localID, nil) + conn, err := tpt.SecureOutbound(context.Background(), initConn, respTransport.localID) require.NoError(t, err) defer conn.Close() }() @@ -396,7 +396,7 @@ func TestPrologueMatches(t *testing.T) { tpt, err := respTransport. WithSessionOptions(Prologue(commonPrologue)) require.NoError(t, err) - conn, err := tpt.SecureInbound(context.Background(), respConn, "", nil) + conn, err := tpt.SecureInbound(context.Background(), respConn, "") require.NoError(t, err) defer conn.Close() <-done @@ -416,14 +416,14 @@ func TestPrologueDoesNotMatchFailsHandshake(t *testing.T) { tpt, err := initTransport. WithSessionOptions(Prologue(initPrologue)) require.NoError(t, err) - _, err = tpt.SecureOutbound(context.Background(), initConn, respTransport.localID, nil) + _, err = tpt.SecureOutbound(context.Background(), initConn, respTransport.localID) require.Error(t, err) }() tpt, err := respTransport.WithSessionOptions(Prologue(respPrologue)) require.NoError(t, err) - _, err = tpt.SecureInbound(context.Background(), respConn, "", nil) + _, err = tpt.SecureInbound(context.Background(), respConn, "") require.Error(t, err) <-done } @@ -460,11 +460,11 @@ func TestEarlyDataAccepted(t *testing.T) { errChan := make(chan error) go func() { - _, err := respTransport.SecureInbound(context.Background(), initConn, "", nil) + _, err := respTransport.SecureInbound(context.Background(), initConn, "") errChan <- err }() - conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID, nil) + conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) require.NoError(t, err) select { case <-time.After(500 * time.Millisecond): @@ -513,14 +513,14 @@ func TestEarlyDataRejected(t *testing.T) { errChan := make(chan error) go func() { - _, err := respTransport.SecureInbound(context.Background(), initConn, "", nil) + _, err := respTransport.SecureInbound(context.Background(), initConn, "") errChan <- err }() // As early data is sent with the last handshake message, the handshake will appear // to succeed for the client. var conn sec.SecureConn - conn, clientErr = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID, nil) + conn, clientErr = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) if clientErr == nil { _, clientErr = conn.Read([]byte{0}) } @@ -571,11 +571,11 @@ func TestEarlyfffDataAcceptedWithNoHandler(t *testing.T) { errChan := make(chan error) go func() { - _, err := respTransport.SecureInbound(context.Background(), initConn, "", nil) + _, err := respTransport.SecureInbound(context.Background(), initConn, "") errChan <- err }() - conn, err := initTransport.SecureOutbound(context.Background(), respConn, respTransport.localID, nil) + conn, err := initTransport.SecureOutbound(context.Background(), respConn, respTransport.localID) require.NoError(t, err) defer conn.Close() diff --git a/p2p/security/tls/cmd/tlsdiag/client.go b/p2p/security/tls/cmd/tlsdiag/client.go index 96025e2b9b..2292bfe0e9 100644 --- a/p2p/security/tls/cmd/tlsdiag/client.go +++ b/p2p/security/tls/cmd/tlsdiag/client.go @@ -34,7 +34,7 @@ func StartClient() error { return err } fmt.Printf(" Peer ID: %s\n", id.Pretty()) - tp, err := libp2ptls.New(priv) + tp, err := libp2ptls.New(priv, nil) if err != nil { return err } @@ -49,7 +49,7 @@ func StartClient() error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - sconn, err := tp.SecureOutbound(ctx, conn, peerID, nil) + sconn, err := tp.SecureOutbound(ctx, conn, peerID) if err != nil { return err } diff --git a/p2p/security/tls/cmd/tlsdiag/server.go b/p2p/security/tls/cmd/tlsdiag/server.go index 2b91de5617..05e4be3f16 100644 --- a/p2p/security/tls/cmd/tlsdiag/server.go +++ b/p2p/security/tls/cmd/tlsdiag/server.go @@ -27,7 +27,7 @@ func StartServer() error { return err } fmt.Printf(" Peer ID: %s\n", id.Pretty()) - tp, err := libp2ptls.New(priv) + tp, err := libp2ptls.New(priv, nil) if err != nil { return err } @@ -57,7 +57,7 @@ func StartServer() error { func handleConn(tp *libp2ptls.Transport, conn net.Conn) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - sconn, err := tp.SecureInbound(ctx, conn, "", nil) + sconn, err := tp.SecureInbound(ctx, conn, "") if err != nil { return err } diff --git a/p2p/security/tls/transport.go b/p2p/security/tls/transport.go index b842e77b7d..b0ee6bea54 100644 --- a/p2p/security/tls/transport.go +++ b/p2p/security/tls/transport.go @@ -27,10 +27,11 @@ type Transport struct { localPeer peer.ID privKey ci.PrivKey + muxers []string } // New creates a TLS encrypted transport -func New(key ci.PrivKey) (*Transport, error) { +func New(key ci.PrivKey, stream_muxers []string) (*Transport, error) { id, err := peer.IDFromPrivateKey(key) if err != nil { return nil, err @@ -38,6 +39,7 @@ func New(key ci.PrivKey) (*Transport, error) { t := &Transport{ localPeer: id, privKey: key, + muxers: stream_muxers, } identity, err := NewIdentity(key) @@ -52,10 +54,10 @@ var _ sec.SecureTransport = &Transport{} // SecureInbound runs the TLS handshake as a server. // If p is empty, connections from any peer are accepted. -func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { +func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { config, keyCh := t.identity.ConfigForPeer(p) // Prepend the prefered muxers list to TLS config. - config.NextProtos = append(muxers, config.NextProtos...) + config.NextProtos = append(t.muxers, config.NextProtos...) cs, err := t.handshake(ctx, tls.Server(insecure, config), keyCh) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) @@ -74,10 +76,10 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer // application data immediately afterwards. // If the handshake fails, the server will close the connection. The client will // notice this after 1 RTT when calling Read. -func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID, muxers []string) (sec.SecureConn, error) { +func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { config, keyCh := t.identity.ConfigForPeer(p) // Prepend the prefered muxers list to TLS config. - config.NextProtos = append(muxers, config.NextProtos...) + config.NextProtos = append(t.muxers, config.NextProtos...) cs, err := t.handshake(ctx, tls.Client(insecure, config), keyCh) if err != nil { insecure.Close() diff --git a/p2p/security/tls/transport_test.go b/p2p/security/tls/transport_test.go index 8a12dc2bef..5339f6f9fc 100644 --- a/p2p/security/tls/transport_test.go +++ b/p2p/security/tls/transport_test.go @@ -28,11 +28,11 @@ import ( "github.com/stretchr/testify/require" ) -var clientMuxerList = [][]string{{}, {"muxer1/1.0.0", "muxer2/1.0.1"}, {"muxer1"}, {}} -var serverMuxerList = [][]string{{}, {"muxer2/1.0.1", "muxer1/1.0.0"}, {}, {"muxer1"}} -var expectedMuxers = []string{"", "muxer2/1.0.1", "", ""} +var clientMuxerList = [][]string{{}, {"muxer1/1.0.0", "muxer2/1.0.1"}, {"muxer1"}, {}, {"muxer1"}} +var serverMuxerList = [][]string{{}, {"muxer2/1.0.1", "muxer1/1.0.0"}, {}, {"muxer1"}, {"muxer2"}} +var expectedMuxers = []string{"", "muxer2/1.0.1", "", "", ""} -const numMuxers = 4 +const numMuxers = 5 func createPeer(t *testing.T) (peer.ID, ic.PrivKey) { var priv ic.PrivKey @@ -91,20 +91,18 @@ func TestHandshakeSucceeds(t *testing.T) { clientID, clientKey := createPeer(t) serverID, serverKey := createPeer(t) var expectedMuxer string - var clientMuxers []string - var serverMuxers []string handshake := func(t *testing.T, clientTransport *Transport, serverTransport *Transport) { clientInsecureConn, serverInsecureConn := connect(t) serverConnChan := make(chan sec.SecureConn) go func() { - serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "", serverMuxers) + serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") require.NoError(t, err) serverConnChan <- serverConn }() - clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID, clientMuxers) + clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) require.NoError(t, err) defer clientConn.Close() @@ -135,15 +133,16 @@ func TestHandshakeSucceeds(t *testing.T) { } // Use standard transports with default TLS configuration - clientTransport, err := New(clientKey) - require.NoError(t, err) - serverTransport, err := New(serverKey) - require.NoError(t, err) + var clientTransport *Transport + var err error + var serverTransport *Transport for i := 0; i < numMuxers; i++ { expectedMuxer = expectedMuxers[i] - clientMuxers = clientMuxerList[i] - serverMuxers = serverMuxerList[i] + clientTransport, err = New(clientKey, clientMuxerList[i]) + require.NoError(t, err) + serverTransport, err = New(serverKey, serverMuxerList[i]) + require.NoError(t, err) t.Run("standard TLS with extension not critical", func(t *testing.T) { handshake(t, clientTransport, serverTransport) }) @@ -180,8 +179,11 @@ func TestHandshakeSucceeds(t *testing.T) { for i := 0; i < numMuxers; i++ { expectedMuxer = expectedMuxers[i] - clientMuxers = clientMuxerList[i] - serverMuxers = serverMuxerList[i] + clientTransport, err = New(clientKey, clientMuxerList[i]) + require.NoError(t, err) + serverTransport, err = New(serverKey, serverMuxerList[i]) + require.NoError(t, err) + t.Run("custom TLS with extension not critical", func(t *testing.T) { handshake(t, clientTransport, serverTransport) }) @@ -212,23 +214,22 @@ func TestHandshakeConnectionCancelations(t *testing.T) { _, clientKey := createPeer(t) serverID, serverKey := createPeer(t) - clientTransport, err := New(clientKey) - require.NoError(t, err) - serverTransport, err := New(serverKey) - require.NoError(t, err) - for i := 0; i < numMuxers; i++ { + clientTransport, err := New(clientKey, clientMuxerList[i]) + require.NoError(t, err) + serverTransport, err := New(serverKey, serverMuxerList[i]) + require.NoError(t, err) t.Run("cancel outgoing connection", func(t *testing.T) { clientInsecureConn, serverInsecureConn := connect(t) errChan := make(chan error) go func() { - _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "", serverMuxerList[i]) + _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") errChan <- err }() ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err = clientTransport.SecureOutbound(ctx, clientInsecureConn, serverID, clientMuxerList[i]) + _, err = clientTransport.SecureOutbound(ctx, clientInsecureConn, serverID) require.ErrorIs(t, err, context.Canceled) require.Error(t, <-errChan) }) @@ -240,7 +241,7 @@ func TestHandshakeConnectionCancelations(t *testing.T) { go func() { ctx, cancel := context.WithCancel(context.Background()) cancel() - conn, err := serverTransport.SecureInbound(ctx, &delayedConn{Conn: serverInsecureConn, delay: 5 * time.Millisecond}, "", serverMuxerList[i]) + conn, err := serverTransport.SecureInbound(ctx, &delayedConn{Conn: serverInsecureConn, delay: 5 * time.Millisecond}, "") // crypto/tls' context handling works by spinning up a separate Go routine that watches the context, // and closes the underlying connection when that context is canceled. // It is therefore not guaranteed (but very likely) that this happens _during_ the TLS handshake. @@ -249,7 +250,7 @@ func TestHandshakeConnectionCancelations(t *testing.T) { } errChan <- err }() - _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID, clientMuxerList[i]) + _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) require.Error(t, err) require.ErrorIs(t, <-errChan, context.Canceled) }) @@ -260,9 +261,9 @@ func TestPeerIDMismatch(t *testing.T) { _, clientKey := createPeer(t) serverID, serverKey := createPeer(t) - serverTransport, err := New(serverKey) + serverTransport, err := New(serverKey, nil) require.NoError(t, err) - clientTransport, err := New(clientKey) + clientTransport, err := New(clientKey, nil) require.NoError(t, err) t.Run("for outgoing connections", func(t *testing.T) { @@ -270,7 +271,7 @@ func TestPeerIDMismatch(t *testing.T) { errChan := make(chan error) go func() { - conn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "", nil) + conn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") // crypto/tls' context handling works by spinning up a separate Go routine that watches the context, // and closes the underlying connection when that context is canceled. // It is therefore not guaranteed (but very likely) that this happens _during_ the TLS handshake. @@ -282,7 +283,7 @@ func TestPeerIDMismatch(t *testing.T) { // dial, but expect the wrong peer ID thirdPartyID, _ := createPeer(t) - _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, thirdPartyID, nil) + _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, thirdPartyID) require.Error(t, err) require.Contains(t, err.Error(), "peer IDs don't match") @@ -303,11 +304,11 @@ func TestPeerIDMismatch(t *testing.T) { go func() { thirdPartyID, _ := createPeer(t) // expect the wrong peer ID - _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, thirdPartyID, nil) + _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, thirdPartyID) errChan <- err }() - conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID, nil) + conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) require.NoError(t, err) _, err = conn.Read([]byte{0}) require.Error(t, err) @@ -537,9 +538,9 @@ func TestInvalidCerts(t *testing.T) { tr := transforms[i] t.Run(fmt.Sprintf("client offending: %s", tr.name), func(t *testing.T) { - serverTransport, err := New(serverKey) + serverTransport, err := New(serverKey, nil) require.NoError(t, err) - clientTransport, err := New(clientKey) + clientTransport, err := New(clientKey, nil) require.NoError(t, err) tr.apply(clientTransport.identity) @@ -547,11 +548,11 @@ func TestInvalidCerts(t *testing.T) { serverErrChan := make(chan error) go func() { - _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "", nil) + _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") serverErrChan <- err }() - conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID, nil) + conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) require.NoError(t, err) clientErrChan := make(chan error) go func() { @@ -580,21 +581,21 @@ func TestInvalidCerts(t *testing.T) { }) t.Run(fmt.Sprintf("server offending: %s", tr.name), func(t *testing.T) { - serverTransport, err := New(serverKey) + serverTransport, err := New(serverKey, nil) require.NoError(t, err) tr.apply(serverTransport.identity) - clientTransport, err := New(clientKey) + clientTransport, err := New(clientKey, nil) require.NoError(t, err) clientInsecureConn, serverInsecureConn := connect(t) errChan := make(chan error) go func() { - _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "", nil) + _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") errChan <- err }() - _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID, nil) + _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) require.Error(t, err) tr.checkErr(t, err) diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 59277d6ce3..317afbe41d 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -209,7 +209,7 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (* if err != nil { return nil, fmt.Errorf("failed to initialize Noise session: %w", err) } - c, err := n.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "", nil) + c, err := n.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "") if err != nil { return nil, err } diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 987a8496b2..8131753a7f 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -229,7 +229,7 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p if err != nil { return nil, fmt.Errorf("failed to create Noise transport: %w", err) } - c, err := n.SecureOutbound(ctx, &webtransportStream{Stream: str, wsess: sess}, p, nil) + c, err := n.SecureOutbound(ctx, &webtransportStream{Stream: str, wsess: sess}, p) if err != nil { return nil, err } From ec0a96c1f44aa92114265f87d20c8098a4ee55b0 Mon Sep 17 00:00:00 2001 From: julian88110 <111450570+julian88110@users.noreply.github.com> Date: Wed, 28 Sep 2022 10:26:47 -0700 Subject: [PATCH 08/13] Address feedback points --- config/constructor_types.go | 27 +- config/reflection_magic.go | 7 +- config/security.go | 11 +- core/network/conn.go | 2 +- p2p/muxer/muxer-multistream/multistream.go | 2 +- p2p/net/upgrader/upgrader.go | 6 +- p2p/security/noise/session.go | 6 +- p2p/security/tls/transport.go | 23 +- p2p/security/tls/transport_test.go | 199 ++++++- p2p/security/tls/transport_test.go-old | 619 +++++++++++++++++++++ 10 files changed, 838 insertions(+), 64 deletions(-) create mode 100644 p2p/security/tls/transport_test.go-old diff --git a/config/constructor_types.go b/config/constructor_types.go index 2f43a2f0da..9c14df2e2c 100644 --- a/config/constructor_types.go +++ b/config/constructor_types.go @@ -11,6 +11,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/pnet" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/transport" @@ -35,44 +36,44 @@ var ( peerIDType = reflect.TypeOf((peer.ID)("")) pskType = reflect.TypeOf((pnet.PSK)(nil)) resolverType = reflect.TypeOf((*madns.Resolver)(nil)) - muxersType = reflect.TypeOf(([]string)(nil)) + muxersType = reflect.TypeOf(([]protocol.ID)(nil)) ) var argTypes = map[reflect.Type]constructor{ - upgraderType: func(_ host.Host, u transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { + upgraderType: func(_ host.Host, u transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return u }, - hostType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { + hostType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return h }, - networkType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { + networkType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return h.Network() }, - pskType: func(_ host.Host, _ transport.Upgrader, psk pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { + pskType: func(_ host.Host, _ transport.Upgrader, psk pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return psk }, - connGaterType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, cg connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { + connGaterType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, cg connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return cg }, - peerIDType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { + peerIDType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return h.ID() }, - privKeyType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { + privKeyType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return h.Peerstore().PrivKey(h.ID()) }, - pubKeyType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { + pubKeyType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return h.Peerstore().PubKey(h.ID()) }, - pstoreType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { + pstoreType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return h.Peerstore() }, - rcmgrType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, rcmgr network.ResourceManager, _ *madns.Resolver, _ []string) interface{} { + rcmgrType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, rcmgr network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return rcmgr }, - resolverType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, r *madns.Resolver, _ []string) interface{} { + resolverType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, r *madns.Resolver, _ []protocol.ID) interface{} { return r }, - muxersType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, muxers []string) interface{} { + muxersType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, muxers []protocol.ID) interface{} { return muxers }, } diff --git a/config/reflection_magic.go b/config/reflection_magic.go index 839384b1e8..0189872abb 100644 --- a/config/reflection_magic.go +++ b/config/reflection_magic.go @@ -10,6 +10,7 @@ import ( "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/pnet" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/transport" madns "github.com/multiformats/go-multiaddr-dns" @@ -82,7 +83,7 @@ func callConstructor(c reflect.Value, args []reflect.Value) (interface{}, error) return val, err } -type constructor func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater, network.ResourceManager, *madns.Resolver, []string) interface{} +type constructor func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater, network.ResourceManager, *madns.Resolver, []protocol.ID) interface{} func makeArgumentConstructors(fnType reflect.Type, argTypes map[reflect.Type]constructor) ([]constructor, error) { params := fnType.NumIn() @@ -133,7 +134,7 @@ func makeConstructor( tptType reflect.Type, argTypes map[reflect.Type]constructor, opts ...interface{}, -) (func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater, network.ResourceManager, *madns.Resolver, []string) (interface{}, error), error) { +) (func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater, network.ResourceManager, *madns.Resolver, []protocol.ID) (interface{}, error), error) { v := reflect.ValueOf(tpt) // avoid panicing on nil/zero value. if v == (reflect.Value{}) { @@ -157,7 +158,7 @@ func makeConstructor( return nil, err } - return func(h host.Host, u transport.Upgrader, psk pnet.PSK, cg connmgr.ConnectionGater, rcmgr network.ResourceManager, resolver *madns.Resolver, muxers []string) (interface{}, error) { + return func(h host.Host, u transport.Upgrader, psk pnet.PSK, cg connmgr.ConnectionGater, rcmgr network.ResourceManager, resolver *madns.Resolver, muxers []protocol.ID) (interface{}, error) { arguments := make([]reflect.Value, 0, len(argConstructors)+len(opts)) for i, makeArg := range argConstructors { if arg := makeArg(h, u, psk, cg, rcmgr, resolver, muxers); arg != nil { diff --git a/config/security.go b/config/security.go index c425fdb73c..edc4a38b16 100644 --- a/config/security.go +++ b/config/security.go @@ -6,13 +6,14 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/sec/insecure" csms "github.com/libp2p/go-libp2p/p2p/net/conn-security-multistream" ) // SecC is a security transport constructor. -type SecC func(h host.Host, muxers []string) (sec.SecureTransport, error) +type SecC func(h host.Host, muxers []protocol.ID) (sec.SecureTransport, error) // MsSecC is a tuple containing a security transport constructor and a protocol // ID. @@ -32,7 +33,7 @@ var securityArgTypes = newArgTypeSet( func SecurityConstructor(security interface{}) (SecC, error) { // Already constructed? if t, ok := security.(sec.SecureTransport); ok { - return func(_ host.Host, _ []string) (sec.SecureTransport, error) { + return func(_ host.Host, _ []protocol.ID) (sec.SecureTransport, error) { return t, nil }, nil } @@ -41,7 +42,7 @@ func SecurityConstructor(security interface{}) (SecC, error) { if err != nil { return nil, err } - return func(h host.Host, muxers []string) (sec.SecureTransport, error) { + return func(h host.Host, muxers []protocol.ID) (sec.SecureTransport, error) { t, err := ctor(h, nil, nil, nil, nil, nil, muxers) if err != nil { return nil, err @@ -65,9 +66,9 @@ func makeSecurityMuxer(h host.Host, tpts []MsSecC, muxers []MsMuxC) (sec.SecureM } transportSet[tptC.ID] = struct{}{} } - muxIds := make([]string, 0, len(muxers)) + muxIds := make([]protocol.ID, 0, len(muxers)) for _, muxc := range muxers { - muxIds = append(muxIds, muxc.ID) + muxIds = append(muxIds, (protocol.ID)(muxc.ID)) } for _, tptC := range tpts { tpt, err := tptC.SecC(h, muxIds) diff --git a/core/network/conn.go b/core/network/conn.go index 8e4eb0aefa..e00ad59f83 100644 --- a/core/network/conn.go +++ b/core/network/conn.go @@ -38,7 +38,7 @@ type Conn interface { type ConnectionState struct { // Early data result derived from security protocol handshake. // For example, Noise handshake payload or TLS/ALPN negotiation. - EarlyData string + NextProto string } // ConnSecurity is the interface that one can mix into a connection interface to diff --git a/p2p/muxer/muxer-multistream/multistream.go b/p2p/muxer/muxer-multistream/multistream.go index 9f78a26b3c..e81ae0ded3 100644 --- a/p2p/muxer/muxer-multistream/multistream.go +++ b/p2p/muxer/muxer-multistream/multistream.go @@ -74,7 +74,7 @@ func (t *Transport) NewConn(nc net.Conn, isServer bool, scope network.PeerScope) return tpt.NewConn(nc, isServer, scope) } -func (t *Transport) GetTranspotByKey(key string) (network.Multiplexer, bool) { +func (t *Transport) GetTransportByKey(key string) (network.Multiplexer, bool) { val, ok := t.tpts[key] return val, ok } diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go index 4785b36ad7..28a839be6c 100644 --- a/p2p/net/upgrader/upgrader.go +++ b/p2p/net/upgrader/upgrader.go @@ -197,12 +197,11 @@ func (u *upgrader) setupSecurity(ctx context.Context, conn net.Conn, p peer.ID, } func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server bool, scope network.PeerScope) (network.MuxedConn, error) { - // TODO: The muxer should take a context. msmuxer, ok := u.muxer.(*msmux.Transport) - muxerSelected := conn.ConnState().EarlyData + muxerSelected := conn.ConnState().NextProto // Use muxer selected from security handshake if available. Otherwise fall back to multistream-selection. if ok && len(muxerSelected) > 0 { - tpt, ok := msmuxer.GetTranspotByKey(muxerSelected) + tpt, ok := msmuxer.GetTransportByKey(muxerSelected) if !ok { return nil, fmt.Errorf("selected a muxer we don't have a transport for") } @@ -214,6 +213,7 @@ func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server b var smconn network.MuxedConn var err error + // TODO: The muxer should take a context. go func() { defer close(done) smconn, err = u.muxer.NewConn(conn, server, scope) diff --git a/p2p/security/noise/session.go b/p2p/security/noise/session.go index cfd5ae2ef3..4bdcc3710d 100644 --- a/p2p/security/noise/session.go +++ b/p2p/security/noise/session.go @@ -41,8 +41,8 @@ type secureSession struct { initiatorEarlyDataHandler, responderEarlyDataHandler EarlyDataHandler - // Early data derived from handshaking. It is empty if not supported. - earlyData string + // Next protocol derived from handshaking. It is empty if not supported. + nextProto string } // newSecureSession creates a Noise session over the given insecureConn Conn, using @@ -113,7 +113,7 @@ func (s *secureSession) RemotePublicKey() crypto.PubKey { } func (s *secureSession) ConnState() network.ConnectionState { - return network.ConnectionState{EarlyData: s.earlyData} + return network.ConnectionState{NextProto: s.nextProto} } func (s *secureSession) SetDeadline(t time.Time) error { diff --git a/p2p/security/tls/transport.go b/p2p/security/tls/transport.go index b0ee6bea54..5ae91be844 100644 --- a/p2p/security/tls/transport.go +++ b/p2p/security/tls/transport.go @@ -13,6 +13,7 @@ import ( ci "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" manet "github.com/multiformats/go-multiaddr/net" @@ -27,11 +28,11 @@ type Transport struct { localPeer peer.ID privKey ci.PrivKey - muxers []string + muxers []protocol.ID } // New creates a TLS encrypted transport -func New(key ci.PrivKey, stream_muxers []string) (*Transport, error) { +func New(key ci.PrivKey, muxers []protocol.ID) (*Transport, error) { id, err := peer.IDFromPrivateKey(key) if err != nil { return nil, err @@ -39,7 +40,7 @@ func New(key ci.PrivKey, stream_muxers []string) (*Transport, error) { t := &Transport{ localPeer: id, privKey: key, - muxers: stream_muxers, + muxers: muxers, } identity, err := NewIdentity(key) @@ -56,8 +57,12 @@ var _ sec.SecureTransport = &Transport{} // If p is empty, connections from any peer are accepted. func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { config, keyCh := t.identity.ConfigForPeer(p) + muxers := make([]string, 0, len(t.muxers)) + for _, muxer := range t.muxers { + muxers = append(muxers, (string)(muxer)) + } // Prepend the prefered muxers list to TLS config. - config.NextProtos = append(t.muxers, config.NextProtos...) + config.NextProtos = append(muxers, config.NextProtos...) cs, err := t.handshake(ctx, tls.Server(insecure, config), keyCh) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) @@ -78,8 +83,12 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer // notice this after 1 RTT when calling Read. func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { config, keyCh := t.identity.ConfigForPeer(p) + muxers := make([]string, 0, len(t.muxers)) + for _, muxer := range t.muxers { + muxers = append(muxers, (string)(muxer)) + } // Prepend the prefered muxers list to TLS config. - config.NextProtos = append(t.muxers, config.NextProtos...) + config.NextProtos = append(muxers, config.NextProtos...) cs, err := t.handshake(ctx, tls.Client(insecure, config), keyCh) if err != nil { insecure.Close() @@ -121,7 +130,7 @@ func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ci.PubKey) (sec.Se } nextProto := tlsConn.ConnectionState().NegotiatedProtocol - if len(nextProto) > 0 && nextProto == "libp2p" { + if nextProto == "libp2p" { nextProto = "" } @@ -131,6 +140,6 @@ func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ci.PubKey) (sec.Se privKey: t.privKey, remotePeer: remotePeerID, remotePubKey: remotePubKey, - connectionState: network.ConnectionState{EarlyData: nextProto}, + connectionState: network.ConnectionState{NextProto: nextProto}, }, nil } diff --git a/p2p/security/tls/transport_test.go b/p2p/security/tls/transport_test.go index 86eff902bc..0921a501bb 100644 --- a/p2p/security/tls/transport_test.go +++ b/p2p/security/tls/transport_test.go @@ -22,18 +22,13 @@ import ( ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -var clientMuxerList = [][]string{{}, {"muxer1/1.0.0", "muxer2/1.0.1"}, {"muxer1"}, {}, {"muxer1"}} -var serverMuxerList = [][]string{{}, {"muxer2/1.0.1", "muxer1/1.0.0"}, {}, {"muxer1"}, {"muxer2"}} -var expectedMuxers = []string{"", "muxer2/1.0.1", "", "", ""} - -const numMuxers = 5 - func createPeer(t *testing.T) (peer.ID, ic.PrivKey) { var priv ic.PrivKey var err error @@ -90,7 +85,6 @@ func isWindowsTCPCloseError(err error) bool { func TestHandshakeSucceeds(t *testing.T) { clientID, clientKey := createPeer(t) serverID, serverKey := createPeer(t) - var expectedMuxer string handshake := func(t *testing.T, clientTransport *Transport, serverTransport *Transport) { clientInsecureConn, serverInsecureConn := connect(t) @@ -122,7 +116,6 @@ func TestHandshakeSucceeds(t *testing.T) { require.Equal(t, serverConn.RemotePeer(), clientID) require.True(t, clientConn.RemotePublicKey().Equals(serverKey.GetPublic()), "server public key mismatch") require.True(t, serverConn.RemotePublicKey().Equals(clientKey.GetPublic()), "client public key mismatch") - require.Equal(t, clientConn.ConnState().EarlyData, expectedMuxer) // exchange some data _, err = serverConn.Write([]byte("foobar")) require.NoError(t, err) @@ -133,27 +126,21 @@ func TestHandshakeSucceeds(t *testing.T) { } // Use standard transports with default TLS configuration - var clientTransport *Transport - var err error - var serverTransport *Transport + clientTransport, err := New(clientKey, nil) + require.NoError(t, err) + serverTransport, err := New(serverKey, nil) + require.NoError(t, err) - for i := 0; i < numMuxers; i++ { - expectedMuxer = expectedMuxers[i] - clientTransport, err = New(clientKey, clientMuxerList[i]) - require.NoError(t, err) - serverTransport, err = New(serverKey, serverMuxerList[i]) - require.NoError(t, err) - t.Run("standard TLS with extension not critical", func(t *testing.T) { - handshake(t, clientTransport, serverTransport) - }) + t.Run("standard TLS with extension not critical", func(t *testing.T) { + handshake(t, clientTransport, serverTransport) + }) - t.Run("standard TLS with extension critical", func(t *testing.T) { - extensionCritical = true - t.Cleanup(func() { extensionCritical = false }) + t.Run("standard TLS with extension critical", func(t *testing.T) { + extensionCritical = true + t.Cleanup(func() { extensionCritical = false }) - handshake(t, clientTransport, serverTransport) - }) - } + handshake(t, clientTransport, serverTransport) + }) // Use transports with custom TLS certificates @@ -177,11 +164,108 @@ func TestHandshakeSucceeds(t *testing.T) { serverTransport.identity, err = NewIdentity(serverKey, WithCertTemplate(serverCertTmpl)) require.NoError(t, err) + t.Run("custom TLS with extension not critical", func(t *testing.T) { + handshake(t, clientTransport, serverTransport) + }) + + t.Run("custom TLS with extension critical", func(t *testing.T) { + extensionCritical = true + t.Cleanup(func() { extensionCritical = false }) + + handshake(t, clientTransport, serverTransport) + }) +} + +func TestHandshakeWithNextProtoSucceeds(t *testing.T) { + var clientMuxerList = [][]protocol.ID{{}, {"muxer1/1.0.0", "muxer2/1.0.1"}, {"muxer1"}, {}, {"muxer1"}} + var serverMuxerList = [][]protocol.ID{{}, {"muxer2/1.0.1", "muxer1/1.0.0"}, {}, {"muxer1"}, {"muxer2"}} + var expectedMuxers = []string{"", "muxer2/1.0.1", "", "", ""} + numMuxers := len(clientMuxerList) + var expectedMuxer string + + clientID, clientKey := createPeer(t) + serverID, serverKey := createPeer(t) + + handshake := func(t *testing.T, clientTransport *Transport, serverTransport *Transport) { + clientInsecureConn, serverInsecureConn := connect(t) + + serverConnChan := make(chan sec.SecureConn) + go func() { + serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + require.NoError(t, err) + serverConnChan <- serverConn + }() + + clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + require.NoError(t, err) + defer clientConn.Close() + + var serverConn sec.SecureConn + select { + case serverConn = <-serverConnChan: + case <-time.After(250 * time.Millisecond): + t.Fatal("expected the server to accept a connection") + } + defer serverConn.Close() + + require.Equal(t, clientConn.LocalPeer(), clientID) + require.Equal(t, serverConn.LocalPeer(), serverID) + require.True(t, clientConn.LocalPrivateKey().Equals(clientKey), "client private key mismatch") + require.True(t, serverConn.LocalPrivateKey().Equals(serverKey), "server private key mismatch") + require.Equal(t, clientConn.RemotePeer(), serverID) + require.Equal(t, serverConn.RemotePeer(), clientID) + require.True(t, clientConn.RemotePublicKey().Equals(serverKey.GetPublic()), "server public key mismatch") + require.True(t, serverConn.RemotePublicKey().Equals(clientKey.GetPublic()), "client public key mismatch") + require.Equal(t, clientConn.ConnState().NextProto, expectedMuxer) + // exchange some data + _, err = serverConn.Write([]byte("foobar")) + require.NoError(t, err) + b := make([]byte, 6) + _, err = clientConn.Read(b) + require.NoError(t, err) + require.Equal(t, string(b), "foobar") + } + + // Iterate through the NextProto combinations. for i := 0; i < numMuxers; i++ { expectedMuxer = expectedMuxers[i] - clientTransport, err = New(clientKey, clientMuxerList[i]) + clientTransport, err := New(clientKey, clientMuxerList[i]) + require.NoError(t, err) + serverTransport, err := New(serverKey, serverMuxerList[i]) require.NoError(t, err) - serverTransport, err = New(serverKey, serverMuxerList[i]) + + // Use standard transports with default TLS configuration + t.Run("standard TLS with extension not critical", func(t *testing.T) { + handshake(t, clientTransport, serverTransport) + }) + + t.Run("standard TLS with extension critical", func(t *testing.T) { + extensionCritical = true + t.Cleanup(func() { extensionCritical = false }) + + handshake(t, clientTransport, serverTransport) + }) + + // Use transports with custom TLS certificates + + // override client identity to use a custom certificate + clientCertTmlp, err := certTemplate() + require.NoError(t, err) + + clientCertTmlp.Subject.CommonName = "client.test.name" + clientCertTmlp.EmailAddresses = []string{"client-unittest@example.com"} + + clientTransport.identity, err = NewIdentity(clientKey, WithCertTemplate(clientCertTmlp)) + require.NoError(t, err) + + // override server identity to use a custom certificate + serverCertTmpl, err := certTemplate() + require.NoError(t, err) + + serverCertTmpl.Subject.CommonName = "server.test.name" + serverCertTmpl.EmailAddresses = []string{"server-unittest@example.com"} + + serverTransport.identity, err = NewIdentity(serverKey, WithCertTemplate(serverCertTmpl)) require.NoError(t, err) t.Run("custom TLS with extension not critical", func(t *testing.T) { @@ -214,13 +298,72 @@ func TestHandshakeConnectionCancellations(t *testing.T) { _, clientKey := createPeer(t) serverID, serverKey := createPeer(t) + clientTransport, err := New(clientKey, nil) + require.NoError(t, err) + serverTransport, err := New(serverKey, nil) + require.NoError(t, err) + + t.Run("cancel outgoing connection", func(t *testing.T) { + clientInsecureConn, serverInsecureConn := connect(t) + + errChan := make(chan error) + go func() { + conn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + // crypto/tls' context handling works by spinning up a separate Go routine that watches the context, + // and closes the underlying connection when that context is canceled. + // It is therefore not guaranteed (but very likely) that this happens _during_ the TLS handshake. + if err == nil { + _, err = conn.Read([]byte{0}) + } + errChan <- err + }() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = clientTransport.SecureOutbound(ctx, clientInsecureConn, serverID) + require.ErrorIs(t, err, context.Canceled) + require.Error(t, <-errChan) + }) + + t.Run("cancel incoming connection", func(t *testing.T) { + clientInsecureConn, serverInsecureConn := connect(t) + + errChan := make(chan error) + go func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + conn, err := serverTransport.SecureInbound(ctx, &delayedConn{Conn: serverInsecureConn, delay: 5 * time.Millisecond}, "") + // crypto/tls' context handling works by spinning up a separate Go routine that watches the context, + // and closes the underlying connection when that context is canceled. + // It is therefore not guaranteed (but very likely) that this happens _during_ the TLS handshake. + if err == nil { + _, err = conn.Read([]byte{0}) + } + errChan <- err + }() + _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + require.Error(t, err) + require.ErrorIs(t, <-errChan, context.Canceled) + }) +} + +func TestHandshakeConnectionWithNextProtoCancellations(t *testing.T) { + var clientMuxerList = [][]protocol.ID{{}, {"muxer1/1.0.0", "muxer2/1.0.1"}, {"muxer1"}, {}, {"muxer1"}} + var serverMuxerList = [][]protocol.ID{{}, {"muxer2/1.0.1", "muxer1/1.0.0"}, {}, {"muxer1"}, {"muxer2"}} + numMuxers := len(clientMuxerList) + + _, clientKey := createPeer(t) + serverID, serverKey := createPeer(t) + + // Test each combination of NextProto extension. for i := 0; i < numMuxers; i++ { clientTransport, err := New(clientKey, clientMuxerList[i]) require.NoError(t, err) serverTransport, err := New(serverKey, serverMuxerList[i]) require.NoError(t, err) + t.Run("cancel outgoing connection", func(t *testing.T) { clientInsecureConn, serverInsecureConn := connect(t) + errChan := make(chan error) go func() { conn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") diff --git a/p2p/security/tls/transport_test.go-old b/p2p/security/tls/transport_test.go-old new file mode 100644 index 0000000000..86eff902bc --- /dev/null +++ b/p2p/security/tls/transport_test.go-old @@ -0,0 +1,619 @@ +package libp2ptls + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "fmt" + "math/big" + mrand "math/rand" + "net" + "runtime" + "strings" + "testing" + "time" + + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/sec" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var clientMuxerList = [][]string{{}, {"muxer1/1.0.0", "muxer2/1.0.1"}, {"muxer1"}, {}, {"muxer1"}} +var serverMuxerList = [][]string{{}, {"muxer2/1.0.1", "muxer1/1.0.0"}, {}, {"muxer1"}, {"muxer2"}} +var expectedMuxers = []string{"", "muxer2/1.0.1", "", "", ""} + +const numMuxers = 5 + +func createPeer(t *testing.T) (peer.ID, ic.PrivKey) { + var priv ic.PrivKey + var err error + switch mrand.Int() % 4 { + case 0: + priv, _, err = ic.GenerateECDSAKeyPair(rand.Reader) + case 1: + priv, _, err = ic.GenerateRSAKeyPair(2048, rand.Reader) + case 2: + priv, _, err = ic.GenerateEd25519Key(rand.Reader) + case 3: + priv, _, err = ic.GenerateSecp256k1Key(rand.Reader) + } + require.NoError(t, err) + id, err := peer.IDFromPrivateKey(priv) + require.NoError(t, err) + t.Logf("using a %s key: %s", priv.Type(), id.Pretty()) + return id, priv +} + +func connect(t *testing.T) (net.Conn, net.Conn) { + ln, err := net.ListenTCP("tcp", nil) + require.NoError(t, err) + defer ln.Close() + serverConnChan := make(chan *net.TCPConn) + go func() { + conn, err := ln.Accept() + assert.NoError(t, err) + sconn := conn.(*net.TCPConn) + serverConnChan <- sconn + }() + conn, err := net.DialTCP("tcp", nil, ln.Addr().(*net.TCPAddr)) + require.NoError(t, err) + sconn := <-serverConnChan + // On Windows we have to set linger to 0, otherwise we'll occasionally run into errors like the following: + // "connectex: Only one usage of each socket address (protocol/network address/port) is normally permitted." + // See https://github.com/libp2p/go-libp2p/issues/1529. + conn.SetLinger(0) + sconn.SetLinger(0) + t.Cleanup(func() { + conn.Close() + sconn.Close() + }) + return conn, sconn +} + +func isWindowsTCPCloseError(err error) bool { + if runtime.GOOS != "windows" { + return false + } + return strings.Contains(err.Error(), "wsarecv: An existing connection was forcibly closed by the remote host") +} + +func TestHandshakeSucceeds(t *testing.T) { + clientID, clientKey := createPeer(t) + serverID, serverKey := createPeer(t) + var expectedMuxer string + + handshake := func(t *testing.T, clientTransport *Transport, serverTransport *Transport) { + clientInsecureConn, serverInsecureConn := connect(t) + + serverConnChan := make(chan sec.SecureConn) + go func() { + serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + require.NoError(t, err) + serverConnChan <- serverConn + }() + + clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + require.NoError(t, err) + defer clientConn.Close() + + var serverConn sec.SecureConn + select { + case serverConn = <-serverConnChan: + case <-time.After(250 * time.Millisecond): + t.Fatal("expected the server to accept a connection") + } + defer serverConn.Close() + + require.Equal(t, clientConn.LocalPeer(), clientID) + require.Equal(t, serverConn.LocalPeer(), serverID) + require.True(t, clientConn.LocalPrivateKey().Equals(clientKey), "client private key mismatch") + require.True(t, serverConn.LocalPrivateKey().Equals(serverKey), "server private key mismatch") + require.Equal(t, clientConn.RemotePeer(), serverID) + require.Equal(t, serverConn.RemotePeer(), clientID) + require.True(t, clientConn.RemotePublicKey().Equals(serverKey.GetPublic()), "server public key mismatch") + require.True(t, serverConn.RemotePublicKey().Equals(clientKey.GetPublic()), "client public key mismatch") + require.Equal(t, clientConn.ConnState().EarlyData, expectedMuxer) + // exchange some data + _, err = serverConn.Write([]byte("foobar")) + require.NoError(t, err) + b := make([]byte, 6) + _, err = clientConn.Read(b) + require.NoError(t, err) + require.Equal(t, string(b), "foobar") + } + + // Use standard transports with default TLS configuration + var clientTransport *Transport + var err error + var serverTransport *Transport + + for i := 0; i < numMuxers; i++ { + expectedMuxer = expectedMuxers[i] + clientTransport, err = New(clientKey, clientMuxerList[i]) + require.NoError(t, err) + serverTransport, err = New(serverKey, serverMuxerList[i]) + require.NoError(t, err) + t.Run("standard TLS with extension not critical", func(t *testing.T) { + handshake(t, clientTransport, serverTransport) + }) + + t.Run("standard TLS with extension critical", func(t *testing.T) { + extensionCritical = true + t.Cleanup(func() { extensionCritical = false }) + + handshake(t, clientTransport, serverTransport) + }) + } + + // Use transports with custom TLS certificates + + // override client identity to use a custom certificate + clientCertTmlp, err := certTemplate() + require.NoError(t, err) + + clientCertTmlp.Subject.CommonName = "client.test.name" + clientCertTmlp.EmailAddresses = []string{"client-unittest@example.com"} + + clientTransport.identity, err = NewIdentity(clientKey, WithCertTemplate(clientCertTmlp)) + require.NoError(t, err) + + // override server identity to use a custom certificate + serverCertTmpl, err := certTemplate() + require.NoError(t, err) + + serverCertTmpl.Subject.CommonName = "server.test.name" + serverCertTmpl.EmailAddresses = []string{"server-unittest@example.com"} + + serverTransport.identity, err = NewIdentity(serverKey, WithCertTemplate(serverCertTmpl)) + require.NoError(t, err) + + for i := 0; i < numMuxers; i++ { + expectedMuxer = expectedMuxers[i] + clientTransport, err = New(clientKey, clientMuxerList[i]) + require.NoError(t, err) + serverTransport, err = New(serverKey, serverMuxerList[i]) + require.NoError(t, err) + + t.Run("custom TLS with extension not critical", func(t *testing.T) { + handshake(t, clientTransport, serverTransport) + }) + + t.Run("custom TLS with extension critical", func(t *testing.T) { + extensionCritical = true + t.Cleanup(func() { extensionCritical = false }) + + handshake(t, clientTransport, serverTransport) + }) + } +} + +// crypto/tls' cancellation logic works by spinning up a separate Go routine that watches the ctx. +// If the ctx is canceled, it kills the handshake. +// We need to make sure that the handshake doesn't complete before that Go routine picks up the cancellation. +type delayedConn struct { + net.Conn + delay time.Duration +} + +func (c *delayedConn) Read(b []byte) (int, error) { + time.Sleep(c.delay) + return c.Conn.Read(b) +} + +func TestHandshakeConnectionCancellations(t *testing.T) { + _, clientKey := createPeer(t) + serverID, serverKey := createPeer(t) + + for i := 0; i < numMuxers; i++ { + clientTransport, err := New(clientKey, clientMuxerList[i]) + require.NoError(t, err) + serverTransport, err := New(serverKey, serverMuxerList[i]) + require.NoError(t, err) + t.Run("cancel outgoing connection", func(t *testing.T) { + clientInsecureConn, serverInsecureConn := connect(t) + errChan := make(chan error) + go func() { + conn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + // crypto/tls' context handling works by spinning up a separate Go routine that watches the context, + // and closes the underlying connection when that context is canceled. + // It is therefore not guaranteed (but very likely) that this happens _during_ the TLS handshake. + if err == nil { + _, err = conn.Read([]byte{0}) + } + errChan <- err + }() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = clientTransport.SecureOutbound(ctx, clientInsecureConn, serverID) + require.ErrorIs(t, err, context.Canceled) + require.Error(t, <-errChan) + }) + + t.Run("cancel incoming connection", func(t *testing.T) { + clientInsecureConn, serverInsecureConn := connect(t) + + errChan := make(chan error) + go func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + conn, err := serverTransport.SecureInbound(ctx, &delayedConn{Conn: serverInsecureConn, delay: 5 * time.Millisecond}, "") + // crypto/tls' context handling works by spinning up a separate Go routine that watches the context, + // and closes the underlying connection when that context is canceled. + // It is therefore not guaranteed (but very likely) that this happens _during_ the TLS handshake. + if err == nil { + _, err = conn.Read([]byte{0}) + } + errChan <- err + }() + _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + require.Error(t, err) + require.ErrorIs(t, <-errChan, context.Canceled) + }) + } +} + +func TestPeerIDMismatch(t *testing.T) { + _, clientKey := createPeer(t) + serverID, serverKey := createPeer(t) + + serverTransport, err := New(serverKey, nil) + require.NoError(t, err) + clientTransport, err := New(clientKey, nil) + require.NoError(t, err) + + t.Run("for outgoing connections", func(t *testing.T) { + clientInsecureConn, serverInsecureConn := connect(t) + + errChan := make(chan error) + go func() { + conn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + // crypto/tls' context handling works by spinning up a separate Go routine that watches the context, + // and closes the underlying connection when that context is canceled. + // It is therefore not guaranteed (but very likely) that this happens _during_ the TLS handshake. + if err == nil { + _, err = conn.Read([]byte{0}) + } + errChan <- err + }() + + // dial, but expect the wrong peer ID + thirdPartyID, _ := createPeer(t) + _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, thirdPartyID) + require.Error(t, err) + require.Contains(t, err.Error(), "peer IDs don't match") + + var serverErr error + select { + case serverErr = <-errChan: + case <-time.After(250 * time.Millisecond): + t.Fatal("expected handshake to return on the server side") + } + require.Error(t, serverErr) + require.Contains(t, serverErr.Error(), "tls: bad certificate") + }) + + t.Run("for incoming connections", func(t *testing.T) { + clientInsecureConn, serverInsecureConn := connect(t) + + errChan := make(chan error) + go func() { + thirdPartyID, _ := createPeer(t) + // expect the wrong peer ID + _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, thirdPartyID) + errChan <- err + }() + + conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + require.NoError(t, err) + _, err = conn.Read([]byte{0}) + require.Error(t, err) + require.Contains(t, err.Error(), "tls: bad certificate") + + var serverErr error + select { + case serverErr = <-errChan: + case <-time.After(250 * time.Millisecond): + t.Fatal("expected handshake to return on the server side") + } + require.Error(t, serverErr) + require.Contains(t, serverErr.Error(), "peer IDs don't match") + }) +} + +func TestInvalidCerts(t *testing.T) { + _, clientKey := createPeer(t) + serverID, serverKey := createPeer(t) + + type transform struct { + name string + apply func(*Identity) + checkErr func(*testing.T, error) // the error that the side validating the chain gets + } + + invalidateCertChain := func(identity *Identity) { + switch identity.config.Certificates[0].PrivateKey.(type) { + case *rsa.PrivateKey: + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + identity.config.Certificates[0].PrivateKey = key + case *ecdsa.PrivateKey: + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + identity.config.Certificates[0].PrivateKey = key + default: + t.Fatal("unexpected private key type") + } + } + + twoCerts := func(identity *Identity) { + tmpl := &x509.Certificate{SerialNumber: big.NewInt(1)} + key1, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + key2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + cert1DER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key1.Public(), key1) + require.NoError(t, err) + cert1, err := x509.ParseCertificate(cert1DER) + require.NoError(t, err) + cert2DER, err := x509.CreateCertificate(rand.Reader, tmpl, cert1, key2.Public(), key1) + require.NoError(t, err) + identity.config.Certificates = []tls.Certificate{{ + Certificate: [][]byte{cert2DER, cert1DER}, + PrivateKey: key2, + }} + } + + getCertWithKey := func(key crypto.Signer, tmpl *x509.Certificate) tls.Certificate { + cert, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key) + require.NoError(t, err) + return tls.Certificate{ + Certificate: [][]byte{cert}, + PrivateKey: key, + } + } + + getCert := func(tmpl *x509.Certificate) tls.Certificate { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + return getCertWithKey(key, tmpl) + } + + expiredCert := func(identity *Identity) { + cert := getCert(&x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(-time.Minute), + ExtraExtensions: []pkix.Extension{ + {Id: extensionID, Value: []byte("foobar")}, + }, + }) + identity.config.Certificates = []tls.Certificate{cert} + } + + noKeyExtension := func(identity *Identity) { + cert := getCert(&x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + }) + identity.config.Certificates = []tls.Certificate{cert} + } + + unparseableKeyExtension := func(identity *Identity) { + cert := getCert(&x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + ExtraExtensions: []pkix.Extension{ + {Id: extensionID, Value: []byte("foobar")}, + }, + }) + identity.config.Certificates = []tls.Certificate{cert} + } + + unparseableKey := func(identity *Identity) { + data, err := asn1.Marshal(signedKey{PubKey: []byte("foobar")}) + require.NoError(t, err) + cert := getCert(&x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + ExtraExtensions: []pkix.Extension{ + {Id: extensionID, Value: data}, + }, + }) + identity.config.Certificates = []tls.Certificate{cert} + } + + tooShortSignature := func(identity *Identity) { + key, _, err := ic.GenerateSecp256k1Key(rand.Reader) + require.NoError(t, err) + keyBytes, err := ic.MarshalPublicKey(key.GetPublic()) + require.NoError(t, err) + data, err := asn1.Marshal(signedKey{ + PubKey: keyBytes, + Signature: []byte("foobar"), + }) + require.NoError(t, err) + cert := getCert(&x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + ExtraExtensions: []pkix.Extension{ + {Id: extensionID, Value: data}, + }, + }) + identity.config.Certificates = []tls.Certificate{cert} + } + + invalidSignature := func(identity *Identity) { + key, _, err := ic.GenerateSecp256k1Key(rand.Reader) + require.NoError(t, err) + keyBytes, err := ic.MarshalPublicKey(key.GetPublic()) + require.NoError(t, err) + signature, err := key.Sign([]byte("foobar")) + require.NoError(t, err) + data, err := asn1.Marshal(signedKey{ + PubKey: keyBytes, + Signature: signature, + }) + require.NoError(t, err) + cert := getCert(&x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + ExtraExtensions: []pkix.Extension{ + {Id: extensionID, Value: data}, + }, + }) + identity.config.Certificates = []tls.Certificate{cert} + } + + transforms := []transform{ + { + name: "private key used in the TLS handshake doesn't match the public key in the cert", + apply: invalidateCertChain, + checkErr: func(t *testing.T, err error) { + if err.Error() != "tls: invalid signature by the client certificate: ECDSA verification failure" && + err.Error() != "tls: invalid signature by the server certificate: ECDSA verification failure" { + t.Fatalf("unexpected error message: %s", err) + } + }, + }, + { + name: "certificate chain contains 2 certs", + apply: twoCerts, + checkErr: func(t *testing.T, err error) { + require.EqualError(t, err, "expected one certificates in the chain") + }, + }, + { + name: "cert is expired", + apply: expiredCert, + checkErr: func(t *testing.T, err error) { + require.Contains(t, err.Error(), "certificate has expired or is not yet valid") + }, + }, + { + name: "cert doesn't have the key extension", + apply: noKeyExtension, + checkErr: func(t *testing.T, err error) { + require.EqualError(t, err, "expected certificate to contain the key extension") + }, + }, + { + name: "key extension not parseable", + apply: unparseableKeyExtension, + checkErr: func(t *testing.T, err error) { require.Contains(t, err.Error(), "asn1") }, + }, + { + name: "key protobuf not parseable", + apply: unparseableKey, + checkErr: func(t *testing.T, err error) { + require.Contains(t, err.Error(), "unmarshalling public key failed: proto:") + }, + }, + { + name: "signature is malformed", + apply: tooShortSignature, + checkErr: func(t *testing.T, err error) { + require.Contains(t, err.Error(), "signature verification failed:") + }, + }, + { + name: "signature is invalid", + apply: invalidSignature, + checkErr: func(t *testing.T, err error) { + require.Contains(t, err.Error(), "signature invalid") + }, + }, + } + + for i := range transforms { + tr := transforms[i] + + t.Run(fmt.Sprintf("client offending: %s", tr.name), func(t *testing.T) { + serverTransport, err := New(serverKey, nil) + require.NoError(t, err) + clientTransport, err := New(clientKey, nil) + require.NoError(t, err) + tr.apply(clientTransport.identity) + + clientInsecureConn, serverInsecureConn := connect(t) + + serverErrChan := make(chan error) + go func() { + _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + serverErrChan <- err + }() + + conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + require.NoError(t, err) + clientErrChan := make(chan error) + go func() { + _, err := conn.Read([]byte{0}) + clientErrChan <- err + }() + select { + case err := <-clientErrChan: + require.Error(t, err) + if err.Error() != "remote error: tls: error decrypting message" && + err.Error() != "remote error: tls: bad certificate" && + !isWindowsTCPCloseError(err) { + t.Errorf("unexpected error: %s", err.Error()) + } + case <-time.After(250 * time.Millisecond): + t.Fatal("expected the server handshake to return") + } + + select { + case err := <-serverErrChan: + require.Error(t, err) + tr.checkErr(t, err) + case <-time.After(250 * time.Millisecond): + t.Fatal("expected the server handshake to return") + } + }) + + t.Run(fmt.Sprintf("server offending: %s", tr.name), func(t *testing.T) { + serverTransport, err := New(serverKey, nil) + require.NoError(t, err) + tr.apply(serverTransport.identity) + clientTransport, err := New(clientKey, nil) + require.NoError(t, err) + + clientInsecureConn, serverInsecureConn := connect(t) + + errChan := make(chan error) + go func() { + _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + errChan <- err + }() + + _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + require.Error(t, err) + tr.checkErr(t, err) + + var serverErr error + select { + case serverErr = <-errChan: + case <-time.After(250 * time.Millisecond): + t.Fatal("expected the server handshake to return") + } + require.Error(t, serverErr) + if !isWindowsTCPCloseError(serverErr) { + require.Contains(t, serverErr.Error(), "remote error: tls:") + } + }) + } +} From 8cf8183719556bd8ef0e6aa1a5e609784c6a7edc Mon Sep 17 00:00:00 2001 From: julian88110 <111450570+julian88110@users.noreply.github.com> Date: Wed, 28 Sep 2022 15:40:46 -0700 Subject: [PATCH 09/13] Update p2p/net/upgrader/upgrader.go Co-authored-by: Marten Seemann --- p2p/net/upgrader/upgrader.go | 1 - 1 file changed, 1 deletion(-) diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go index 28a839be6c..c505b5cd9c 100644 --- a/p2p/net/upgrader/upgrader.go +++ b/p2p/net/upgrader/upgrader.go @@ -79,7 +79,6 @@ type upgrader struct { var _ transport.Upgrader = &upgrader{} func New(secureMuxer sec.SecureMuxer, muxer network.Multiplexer, opts ...Option) (transport.Upgrader, error) { - u := &upgrader{ secure: secureMuxer, muxer: muxer, From 5a5defb2e142c7f3ff28f0a5574b6dabca2f7f9d Mon Sep 17 00:00:00 2001 From: julian88110 <111450570+julian88110@users.noreply.github.com> Date: Wed, 28 Sep 2022 16:43:52 -0700 Subject: [PATCH 10/13] clean up accidental checked file. --- p2p/security/tls/transport_test.go-old | 619 ------------------------- 1 file changed, 619 deletions(-) delete mode 100644 p2p/security/tls/transport_test.go-old diff --git a/p2p/security/tls/transport_test.go-old b/p2p/security/tls/transport_test.go-old deleted file mode 100644 index 86eff902bc..0000000000 --- a/p2p/security/tls/transport_test.go-old +++ /dev/null @@ -1,619 +0,0 @@ -package libp2ptls - -import ( - "context" - "crypto" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/asn1" - "fmt" - "math/big" - mrand "math/rand" - "net" - "runtime" - "strings" - "testing" - "time" - - ic "github.com/libp2p/go-libp2p/core/crypto" - "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/core/sec" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var clientMuxerList = [][]string{{}, {"muxer1/1.0.0", "muxer2/1.0.1"}, {"muxer1"}, {}, {"muxer1"}} -var serverMuxerList = [][]string{{}, {"muxer2/1.0.1", "muxer1/1.0.0"}, {}, {"muxer1"}, {"muxer2"}} -var expectedMuxers = []string{"", "muxer2/1.0.1", "", "", ""} - -const numMuxers = 5 - -func createPeer(t *testing.T) (peer.ID, ic.PrivKey) { - var priv ic.PrivKey - var err error - switch mrand.Int() % 4 { - case 0: - priv, _, err = ic.GenerateECDSAKeyPair(rand.Reader) - case 1: - priv, _, err = ic.GenerateRSAKeyPair(2048, rand.Reader) - case 2: - priv, _, err = ic.GenerateEd25519Key(rand.Reader) - case 3: - priv, _, err = ic.GenerateSecp256k1Key(rand.Reader) - } - require.NoError(t, err) - id, err := peer.IDFromPrivateKey(priv) - require.NoError(t, err) - t.Logf("using a %s key: %s", priv.Type(), id.Pretty()) - return id, priv -} - -func connect(t *testing.T) (net.Conn, net.Conn) { - ln, err := net.ListenTCP("tcp", nil) - require.NoError(t, err) - defer ln.Close() - serverConnChan := make(chan *net.TCPConn) - go func() { - conn, err := ln.Accept() - assert.NoError(t, err) - sconn := conn.(*net.TCPConn) - serverConnChan <- sconn - }() - conn, err := net.DialTCP("tcp", nil, ln.Addr().(*net.TCPAddr)) - require.NoError(t, err) - sconn := <-serverConnChan - // On Windows we have to set linger to 0, otherwise we'll occasionally run into errors like the following: - // "connectex: Only one usage of each socket address (protocol/network address/port) is normally permitted." - // See https://github.com/libp2p/go-libp2p/issues/1529. - conn.SetLinger(0) - sconn.SetLinger(0) - t.Cleanup(func() { - conn.Close() - sconn.Close() - }) - return conn, sconn -} - -func isWindowsTCPCloseError(err error) bool { - if runtime.GOOS != "windows" { - return false - } - return strings.Contains(err.Error(), "wsarecv: An existing connection was forcibly closed by the remote host") -} - -func TestHandshakeSucceeds(t *testing.T) { - clientID, clientKey := createPeer(t) - serverID, serverKey := createPeer(t) - var expectedMuxer string - - handshake := func(t *testing.T, clientTransport *Transport, serverTransport *Transport) { - clientInsecureConn, serverInsecureConn := connect(t) - - serverConnChan := make(chan sec.SecureConn) - go func() { - serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") - require.NoError(t, err) - serverConnChan <- serverConn - }() - - clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) - require.NoError(t, err) - defer clientConn.Close() - - var serverConn sec.SecureConn - select { - case serverConn = <-serverConnChan: - case <-time.After(250 * time.Millisecond): - t.Fatal("expected the server to accept a connection") - } - defer serverConn.Close() - - require.Equal(t, clientConn.LocalPeer(), clientID) - require.Equal(t, serverConn.LocalPeer(), serverID) - require.True(t, clientConn.LocalPrivateKey().Equals(clientKey), "client private key mismatch") - require.True(t, serverConn.LocalPrivateKey().Equals(serverKey), "server private key mismatch") - require.Equal(t, clientConn.RemotePeer(), serverID) - require.Equal(t, serverConn.RemotePeer(), clientID) - require.True(t, clientConn.RemotePublicKey().Equals(serverKey.GetPublic()), "server public key mismatch") - require.True(t, serverConn.RemotePublicKey().Equals(clientKey.GetPublic()), "client public key mismatch") - require.Equal(t, clientConn.ConnState().EarlyData, expectedMuxer) - // exchange some data - _, err = serverConn.Write([]byte("foobar")) - require.NoError(t, err) - b := make([]byte, 6) - _, err = clientConn.Read(b) - require.NoError(t, err) - require.Equal(t, string(b), "foobar") - } - - // Use standard transports with default TLS configuration - var clientTransport *Transport - var err error - var serverTransport *Transport - - for i := 0; i < numMuxers; i++ { - expectedMuxer = expectedMuxers[i] - clientTransport, err = New(clientKey, clientMuxerList[i]) - require.NoError(t, err) - serverTransport, err = New(serverKey, serverMuxerList[i]) - require.NoError(t, err) - t.Run("standard TLS with extension not critical", func(t *testing.T) { - handshake(t, clientTransport, serverTransport) - }) - - t.Run("standard TLS with extension critical", func(t *testing.T) { - extensionCritical = true - t.Cleanup(func() { extensionCritical = false }) - - handshake(t, clientTransport, serverTransport) - }) - } - - // Use transports with custom TLS certificates - - // override client identity to use a custom certificate - clientCertTmlp, err := certTemplate() - require.NoError(t, err) - - clientCertTmlp.Subject.CommonName = "client.test.name" - clientCertTmlp.EmailAddresses = []string{"client-unittest@example.com"} - - clientTransport.identity, err = NewIdentity(clientKey, WithCertTemplate(clientCertTmlp)) - require.NoError(t, err) - - // override server identity to use a custom certificate - serverCertTmpl, err := certTemplate() - require.NoError(t, err) - - serverCertTmpl.Subject.CommonName = "server.test.name" - serverCertTmpl.EmailAddresses = []string{"server-unittest@example.com"} - - serverTransport.identity, err = NewIdentity(serverKey, WithCertTemplate(serverCertTmpl)) - require.NoError(t, err) - - for i := 0; i < numMuxers; i++ { - expectedMuxer = expectedMuxers[i] - clientTransport, err = New(clientKey, clientMuxerList[i]) - require.NoError(t, err) - serverTransport, err = New(serverKey, serverMuxerList[i]) - require.NoError(t, err) - - t.Run("custom TLS with extension not critical", func(t *testing.T) { - handshake(t, clientTransport, serverTransport) - }) - - t.Run("custom TLS with extension critical", func(t *testing.T) { - extensionCritical = true - t.Cleanup(func() { extensionCritical = false }) - - handshake(t, clientTransport, serverTransport) - }) - } -} - -// crypto/tls' cancellation logic works by spinning up a separate Go routine that watches the ctx. -// If the ctx is canceled, it kills the handshake. -// We need to make sure that the handshake doesn't complete before that Go routine picks up the cancellation. -type delayedConn struct { - net.Conn - delay time.Duration -} - -func (c *delayedConn) Read(b []byte) (int, error) { - time.Sleep(c.delay) - return c.Conn.Read(b) -} - -func TestHandshakeConnectionCancellations(t *testing.T) { - _, clientKey := createPeer(t) - serverID, serverKey := createPeer(t) - - for i := 0; i < numMuxers; i++ { - clientTransport, err := New(clientKey, clientMuxerList[i]) - require.NoError(t, err) - serverTransport, err := New(serverKey, serverMuxerList[i]) - require.NoError(t, err) - t.Run("cancel outgoing connection", func(t *testing.T) { - clientInsecureConn, serverInsecureConn := connect(t) - errChan := make(chan error) - go func() { - conn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") - // crypto/tls' context handling works by spinning up a separate Go routine that watches the context, - // and closes the underlying connection when that context is canceled. - // It is therefore not guaranteed (but very likely) that this happens _during_ the TLS handshake. - if err == nil { - _, err = conn.Read([]byte{0}) - } - errChan <- err - }() - ctx, cancel := context.WithCancel(context.Background()) - cancel() - _, err = clientTransport.SecureOutbound(ctx, clientInsecureConn, serverID) - require.ErrorIs(t, err, context.Canceled) - require.Error(t, <-errChan) - }) - - t.Run("cancel incoming connection", func(t *testing.T) { - clientInsecureConn, serverInsecureConn := connect(t) - - errChan := make(chan error) - go func() { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - conn, err := serverTransport.SecureInbound(ctx, &delayedConn{Conn: serverInsecureConn, delay: 5 * time.Millisecond}, "") - // crypto/tls' context handling works by spinning up a separate Go routine that watches the context, - // and closes the underlying connection when that context is canceled. - // It is therefore not guaranteed (but very likely) that this happens _during_ the TLS handshake. - if err == nil { - _, err = conn.Read([]byte{0}) - } - errChan <- err - }() - _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) - require.Error(t, err) - require.ErrorIs(t, <-errChan, context.Canceled) - }) - } -} - -func TestPeerIDMismatch(t *testing.T) { - _, clientKey := createPeer(t) - serverID, serverKey := createPeer(t) - - serverTransport, err := New(serverKey, nil) - require.NoError(t, err) - clientTransport, err := New(clientKey, nil) - require.NoError(t, err) - - t.Run("for outgoing connections", func(t *testing.T) { - clientInsecureConn, serverInsecureConn := connect(t) - - errChan := make(chan error) - go func() { - conn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") - // crypto/tls' context handling works by spinning up a separate Go routine that watches the context, - // and closes the underlying connection when that context is canceled. - // It is therefore not guaranteed (but very likely) that this happens _during_ the TLS handshake. - if err == nil { - _, err = conn.Read([]byte{0}) - } - errChan <- err - }() - - // dial, but expect the wrong peer ID - thirdPartyID, _ := createPeer(t) - _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, thirdPartyID) - require.Error(t, err) - require.Contains(t, err.Error(), "peer IDs don't match") - - var serverErr error - select { - case serverErr = <-errChan: - case <-time.After(250 * time.Millisecond): - t.Fatal("expected handshake to return on the server side") - } - require.Error(t, serverErr) - require.Contains(t, serverErr.Error(), "tls: bad certificate") - }) - - t.Run("for incoming connections", func(t *testing.T) { - clientInsecureConn, serverInsecureConn := connect(t) - - errChan := make(chan error) - go func() { - thirdPartyID, _ := createPeer(t) - // expect the wrong peer ID - _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, thirdPartyID) - errChan <- err - }() - - conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) - require.NoError(t, err) - _, err = conn.Read([]byte{0}) - require.Error(t, err) - require.Contains(t, err.Error(), "tls: bad certificate") - - var serverErr error - select { - case serverErr = <-errChan: - case <-time.After(250 * time.Millisecond): - t.Fatal("expected handshake to return on the server side") - } - require.Error(t, serverErr) - require.Contains(t, serverErr.Error(), "peer IDs don't match") - }) -} - -func TestInvalidCerts(t *testing.T) { - _, clientKey := createPeer(t) - serverID, serverKey := createPeer(t) - - type transform struct { - name string - apply func(*Identity) - checkErr func(*testing.T, error) // the error that the side validating the chain gets - } - - invalidateCertChain := func(identity *Identity) { - switch identity.config.Certificates[0].PrivateKey.(type) { - case *rsa.PrivateKey: - key, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - identity.config.Certificates[0].PrivateKey = key - case *ecdsa.PrivateKey: - key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) - identity.config.Certificates[0].PrivateKey = key - default: - t.Fatal("unexpected private key type") - } - } - - twoCerts := func(identity *Identity) { - tmpl := &x509.Certificate{SerialNumber: big.NewInt(1)} - key1, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) - key2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) - cert1DER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key1.Public(), key1) - require.NoError(t, err) - cert1, err := x509.ParseCertificate(cert1DER) - require.NoError(t, err) - cert2DER, err := x509.CreateCertificate(rand.Reader, tmpl, cert1, key2.Public(), key1) - require.NoError(t, err) - identity.config.Certificates = []tls.Certificate{{ - Certificate: [][]byte{cert2DER, cert1DER}, - PrivateKey: key2, - }} - } - - getCertWithKey := func(key crypto.Signer, tmpl *x509.Certificate) tls.Certificate { - cert, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key) - require.NoError(t, err) - return tls.Certificate{ - Certificate: [][]byte{cert}, - PrivateKey: key, - } - } - - getCert := func(tmpl *x509.Certificate) tls.Certificate { - key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) - return getCertWithKey(key, tmpl) - } - - expiredCert := func(identity *Identity) { - cert := getCert(&x509.Certificate{ - SerialNumber: big.NewInt(1), - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(-time.Minute), - ExtraExtensions: []pkix.Extension{ - {Id: extensionID, Value: []byte("foobar")}, - }, - }) - identity.config.Certificates = []tls.Certificate{cert} - } - - noKeyExtension := func(identity *Identity) { - cert := getCert(&x509.Certificate{ - SerialNumber: big.NewInt(1), - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(time.Hour), - }) - identity.config.Certificates = []tls.Certificate{cert} - } - - unparseableKeyExtension := func(identity *Identity) { - cert := getCert(&x509.Certificate{ - SerialNumber: big.NewInt(1), - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(time.Hour), - ExtraExtensions: []pkix.Extension{ - {Id: extensionID, Value: []byte("foobar")}, - }, - }) - identity.config.Certificates = []tls.Certificate{cert} - } - - unparseableKey := func(identity *Identity) { - data, err := asn1.Marshal(signedKey{PubKey: []byte("foobar")}) - require.NoError(t, err) - cert := getCert(&x509.Certificate{ - SerialNumber: big.NewInt(1), - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(time.Hour), - ExtraExtensions: []pkix.Extension{ - {Id: extensionID, Value: data}, - }, - }) - identity.config.Certificates = []tls.Certificate{cert} - } - - tooShortSignature := func(identity *Identity) { - key, _, err := ic.GenerateSecp256k1Key(rand.Reader) - require.NoError(t, err) - keyBytes, err := ic.MarshalPublicKey(key.GetPublic()) - require.NoError(t, err) - data, err := asn1.Marshal(signedKey{ - PubKey: keyBytes, - Signature: []byte("foobar"), - }) - require.NoError(t, err) - cert := getCert(&x509.Certificate{ - SerialNumber: big.NewInt(1), - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(time.Hour), - ExtraExtensions: []pkix.Extension{ - {Id: extensionID, Value: data}, - }, - }) - identity.config.Certificates = []tls.Certificate{cert} - } - - invalidSignature := func(identity *Identity) { - key, _, err := ic.GenerateSecp256k1Key(rand.Reader) - require.NoError(t, err) - keyBytes, err := ic.MarshalPublicKey(key.GetPublic()) - require.NoError(t, err) - signature, err := key.Sign([]byte("foobar")) - require.NoError(t, err) - data, err := asn1.Marshal(signedKey{ - PubKey: keyBytes, - Signature: signature, - }) - require.NoError(t, err) - cert := getCert(&x509.Certificate{ - SerialNumber: big.NewInt(1), - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(time.Hour), - ExtraExtensions: []pkix.Extension{ - {Id: extensionID, Value: data}, - }, - }) - identity.config.Certificates = []tls.Certificate{cert} - } - - transforms := []transform{ - { - name: "private key used in the TLS handshake doesn't match the public key in the cert", - apply: invalidateCertChain, - checkErr: func(t *testing.T, err error) { - if err.Error() != "tls: invalid signature by the client certificate: ECDSA verification failure" && - err.Error() != "tls: invalid signature by the server certificate: ECDSA verification failure" { - t.Fatalf("unexpected error message: %s", err) - } - }, - }, - { - name: "certificate chain contains 2 certs", - apply: twoCerts, - checkErr: func(t *testing.T, err error) { - require.EqualError(t, err, "expected one certificates in the chain") - }, - }, - { - name: "cert is expired", - apply: expiredCert, - checkErr: func(t *testing.T, err error) { - require.Contains(t, err.Error(), "certificate has expired or is not yet valid") - }, - }, - { - name: "cert doesn't have the key extension", - apply: noKeyExtension, - checkErr: func(t *testing.T, err error) { - require.EqualError(t, err, "expected certificate to contain the key extension") - }, - }, - { - name: "key extension not parseable", - apply: unparseableKeyExtension, - checkErr: func(t *testing.T, err error) { require.Contains(t, err.Error(), "asn1") }, - }, - { - name: "key protobuf not parseable", - apply: unparseableKey, - checkErr: func(t *testing.T, err error) { - require.Contains(t, err.Error(), "unmarshalling public key failed: proto:") - }, - }, - { - name: "signature is malformed", - apply: tooShortSignature, - checkErr: func(t *testing.T, err error) { - require.Contains(t, err.Error(), "signature verification failed:") - }, - }, - { - name: "signature is invalid", - apply: invalidSignature, - checkErr: func(t *testing.T, err error) { - require.Contains(t, err.Error(), "signature invalid") - }, - }, - } - - for i := range transforms { - tr := transforms[i] - - t.Run(fmt.Sprintf("client offending: %s", tr.name), func(t *testing.T) { - serverTransport, err := New(serverKey, nil) - require.NoError(t, err) - clientTransport, err := New(clientKey, nil) - require.NoError(t, err) - tr.apply(clientTransport.identity) - - clientInsecureConn, serverInsecureConn := connect(t) - - serverErrChan := make(chan error) - go func() { - _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") - serverErrChan <- err - }() - - conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) - require.NoError(t, err) - clientErrChan := make(chan error) - go func() { - _, err := conn.Read([]byte{0}) - clientErrChan <- err - }() - select { - case err := <-clientErrChan: - require.Error(t, err) - if err.Error() != "remote error: tls: error decrypting message" && - err.Error() != "remote error: tls: bad certificate" && - !isWindowsTCPCloseError(err) { - t.Errorf("unexpected error: %s", err.Error()) - } - case <-time.After(250 * time.Millisecond): - t.Fatal("expected the server handshake to return") - } - - select { - case err := <-serverErrChan: - require.Error(t, err) - tr.checkErr(t, err) - case <-time.After(250 * time.Millisecond): - t.Fatal("expected the server handshake to return") - } - }) - - t.Run(fmt.Sprintf("server offending: %s", tr.name), func(t *testing.T) { - serverTransport, err := New(serverKey, nil) - require.NoError(t, err) - tr.apply(serverTransport.identity) - clientTransport, err := New(clientKey, nil) - require.NoError(t, err) - - clientInsecureConn, serverInsecureConn := connect(t) - - errChan := make(chan error) - go func() { - _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") - errChan <- err - }() - - _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) - require.Error(t, err) - tr.checkErr(t, err) - - var serverErr error - select { - case serverErr = <-errChan: - case <-time.After(250 * time.Millisecond): - t.Fatal("expected the server handshake to return") - } - require.Error(t, serverErr) - if !isWindowsTCPCloseError(serverErr) { - require.Contains(t, serverErr.Error(), "remote error: tls:") - } - }) - } -} From 7f8fc8924b84172281245d204a1b87cf81dc1748 Mon Sep 17 00:00:00 2001 From: julian88110 <111450570+julian88110@users.noreply.github.com> Date: Thu, 29 Sep 2022 11:57:06 -0700 Subject: [PATCH 11/13] Review points round 2 --- config/security.go | 2 +- core/network/conn.go | 5 +- p2p/net/upgrader/upgrader.go | 4 +- p2p/security/noise/session.go | 5 +- p2p/security/tls/transport.go | 6 +- p2p/security/tls/transport_test.go | 130 +++++------------------------ 6 files changed, 34 insertions(+), 118 deletions(-) diff --git a/config/security.go b/config/security.go index edc4a38b16..6f15b63462 100644 --- a/config/security.go +++ b/config/security.go @@ -68,7 +68,7 @@ func makeSecurityMuxer(h host.Host, tpts []MsSecC, muxers []MsMuxC) (sec.SecureM } muxIds := make([]protocol.ID, 0, len(muxers)) for _, muxc := range muxers { - muxIds = append(muxIds, (protocol.ID)(muxc.ID)) + muxIds = append(muxIds, protocol.ID(muxc.ID)) } for _, tptC := range tpts { tpt, err := tptC.SecC(h, muxIds) diff --git a/core/network/conn.go b/core/network/conn.go index e00ad59f83..18414b062c 100644 --- a/core/network/conn.go +++ b/core/network/conn.go @@ -36,8 +36,9 @@ type Conn interface { // ConnectionState holds extra information releated to the ConnSecurity entity. type ConnectionState struct { - // Early data result derived from security protocol handshake. - // For example, Noise handshake payload or TLS/ALPN negotiation. + // The next protocol used for stream muxer selection. This is derived from + // security protocol handshake, for example, Noise handshake payload or + // TLS/ALPN negotiation. NextProto string } diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go index c505b5cd9c..2ef60b82ac 100644 --- a/p2p/net/upgrader/upgrader.go +++ b/p2p/net/upgrader/upgrader.go @@ -13,9 +13,9 @@ import ( ipnet "github.com/libp2p/go-libp2p/core/pnet" "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/transport" + msmux "github.com/libp2p/go-libp2p/p2p/muxer/muxer-multistream" "github.com/libp2p/go-libp2p/p2p/net/pnet" - msmux "github.com/libp2p/go-libp2p/p2p/muxer/muxer-multistream" manet "github.com/multiformats/go-multiaddr/net" ) @@ -202,7 +202,7 @@ func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server b if ok && len(muxerSelected) > 0 { tpt, ok := msmuxer.GetTransportByKey(muxerSelected) if !ok { - return nil, fmt.Errorf("selected a muxer we don't have a transport for") + return nil, fmt.Errorf("selected a muxer we don't know: %s", muxerSelected) } return tpt.NewConn(conn, server, scope) diff --git a/p2p/security/noise/session.go b/p2p/security/noise/session.go index 4bdcc3710d..f1286b9ffb 100644 --- a/p2p/security/noise/session.go +++ b/p2p/security/noise/session.go @@ -40,9 +40,6 @@ type secureSession struct { prologue []byte initiatorEarlyDataHandler, responderEarlyDataHandler EarlyDataHandler - - // Next protocol derived from handshaking. It is empty if not supported. - nextProto string } // newSecureSession creates a Noise session over the given insecureConn Conn, using @@ -113,7 +110,7 @@ func (s *secureSession) RemotePublicKey() crypto.PubKey { } func (s *secureSession) ConnState() network.ConnectionState { - return network.ConnectionState{NextProto: s.nextProto} + return network.ConnectionState{} } func (s *secureSession) SetDeadline(t time.Time) error { diff --git a/p2p/security/tls/transport.go b/p2p/security/tls/transport.go index 5ae91be844..b5549c78db 100644 --- a/p2p/security/tls/transport.go +++ b/p2p/security/tls/transport.go @@ -59,7 +59,7 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer config, keyCh := t.identity.ConfigForPeer(p) muxers := make([]string, 0, len(t.muxers)) for _, muxer := range t.muxers { - muxers = append(muxers, (string)(muxer)) + muxers = append(muxers, string(muxer)) } // Prepend the prefered muxers list to TLS config. config.NextProtos = append(muxers, config.NextProtos...) @@ -130,6 +130,10 @@ func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ci.PubKey) (sec.Se } nextProto := tlsConn.ConnectionState().NegotiatedProtocol + // The special nextProto ID "libp2p" was used in previous versions upto + // v0.23.2 as the TLS ALPN extension field. If we see this special ID + // selected, that means we are handshaking with an old version of libp2p. + // In this case return empty nextProto to indicate no ALPN is selected. if nextProto == "libp2p" { nextProto = "" } diff --git a/p2p/security/tls/transport_test.go b/p2p/security/tls/transport_test.go index 0921a501bb..a7d4eacd78 100644 --- a/p2p/security/tls/transport_test.go +++ b/p2p/security/tls/transport_test.go @@ -176,11 +176,25 @@ func TestHandshakeSucceeds(t *testing.T) { }) } +type testcase struct { + clientProtos []protocol.ID + serverProtos []protocol.ID + expectedResult string +} + func TestHandshakeWithNextProtoSucceeds(t *testing.T) { - var clientMuxerList = [][]protocol.ID{{}, {"muxer1/1.0.0", "muxer2/1.0.1"}, {"muxer1"}, {}, {"muxer1"}} - var serverMuxerList = [][]protocol.ID{{}, {"muxer2/1.0.1", "muxer1/1.0.0"}, {}, {"muxer1"}, {"muxer2"}} - var expectedMuxers = []string{"", "muxer2/1.0.1", "", "", ""} - numMuxers := len(clientMuxerList) + + tests := []testcase{ + {clientProtos: nil, serverProtos: nil, expectedResult: ""}, + {[]protocol.ID{"muxer1/1.0.0", "muxer2/1.0.1"}, []protocol.ID{"muxer2/1.0.1", "muxer1/1.0.0"}, "muxer2/1.0.1"}, + {[]protocol.ID{"muxer1/1.0.0", "muxer2/1.0.1", "libp2p"}, []protocol.ID{"muxer2/1.0.1", "muxer1/1.0.0", "libp2p"}, "muxer2/1.0.1"}, + {[]protocol.ID{"muxer1/1.0.0", "libp2p"}, []protocol.ID{"libp2p"}, ""}, + {[]protocol.ID{"libp2p"}, []protocol.ID{"libp2p"}, ""}, + {[]protocol.ID{"muxer1"}, []protocol.ID{}, ""}, + {[]protocol.ID{}, []protocol.ID{"muxer1"}, ""}, + {[]protocol.ID{"muxer2"}, []protocol.ID{"muxer1"}, ""}, + } + numMuxers := len(tests) var expectedMuxer string clientID, clientKey := createPeer(t) @@ -228,54 +242,13 @@ func TestHandshakeWithNextProtoSucceeds(t *testing.T) { // Iterate through the NextProto combinations. for i := 0; i < numMuxers; i++ { - expectedMuxer = expectedMuxers[i] - clientTransport, err := New(clientKey, clientMuxerList[i]) + expectedMuxer = tests[i].expectedResult + clientTransport, err := New(clientKey, tests[i].clientProtos) require.NoError(t, err) - serverTransport, err := New(serverKey, serverMuxerList[i]) + serverTransport, err := New(serverKey, tests[i].serverProtos) require.NoError(t, err) - // Use standard transports with default TLS configuration - t.Run("standard TLS with extension not critical", func(t *testing.T) { - handshake(t, clientTransport, serverTransport) - }) - - t.Run("standard TLS with extension critical", func(t *testing.T) { - extensionCritical = true - t.Cleanup(func() { extensionCritical = false }) - - handshake(t, clientTransport, serverTransport) - }) - - // Use transports with custom TLS certificates - - // override client identity to use a custom certificate - clientCertTmlp, err := certTemplate() - require.NoError(t, err) - - clientCertTmlp.Subject.CommonName = "client.test.name" - clientCertTmlp.EmailAddresses = []string{"client-unittest@example.com"} - - clientTransport.identity, err = NewIdentity(clientKey, WithCertTemplate(clientCertTmlp)) - require.NoError(t, err) - - // override server identity to use a custom certificate - serverCertTmpl, err := certTemplate() - require.NoError(t, err) - - serverCertTmpl.Subject.CommonName = "server.test.name" - serverCertTmpl.EmailAddresses = []string{"server-unittest@example.com"} - - serverTransport.identity, err = NewIdentity(serverKey, WithCertTemplate(serverCertTmpl)) - require.NoError(t, err) - - t.Run("custom TLS with extension not critical", func(t *testing.T) { - handshake(t, clientTransport, serverTransport) - }) - - t.Run("custom TLS with extension critical", func(t *testing.T) { - extensionCritical = true - t.Cleanup(func() { extensionCritical = false }) - + t.Run("TLS handshake with ALPN extension", func(t *testing.T) { handshake(t, clientTransport, serverTransport) }) } @@ -346,65 +319,6 @@ func TestHandshakeConnectionCancellations(t *testing.T) { }) } -func TestHandshakeConnectionWithNextProtoCancellations(t *testing.T) { - var clientMuxerList = [][]protocol.ID{{}, {"muxer1/1.0.0", "muxer2/1.0.1"}, {"muxer1"}, {}, {"muxer1"}} - var serverMuxerList = [][]protocol.ID{{}, {"muxer2/1.0.1", "muxer1/1.0.0"}, {}, {"muxer1"}, {"muxer2"}} - numMuxers := len(clientMuxerList) - - _, clientKey := createPeer(t) - serverID, serverKey := createPeer(t) - - // Test each combination of NextProto extension. - for i := 0; i < numMuxers; i++ { - clientTransport, err := New(clientKey, clientMuxerList[i]) - require.NoError(t, err) - serverTransport, err := New(serverKey, serverMuxerList[i]) - require.NoError(t, err) - - t.Run("cancel outgoing connection", func(t *testing.T) { - clientInsecureConn, serverInsecureConn := connect(t) - - errChan := make(chan error) - go func() { - conn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") - // crypto/tls' context handling works by spinning up a separate Go routine that watches the context, - // and closes the underlying connection when that context is canceled. - // It is therefore not guaranteed (but very likely) that this happens _during_ the TLS handshake. - if err == nil { - _, err = conn.Read([]byte{0}) - } - errChan <- err - }() - ctx, cancel := context.WithCancel(context.Background()) - cancel() - _, err = clientTransport.SecureOutbound(ctx, clientInsecureConn, serverID) - require.ErrorIs(t, err, context.Canceled) - require.Error(t, <-errChan) - }) - - t.Run("cancel incoming connection", func(t *testing.T) { - clientInsecureConn, serverInsecureConn := connect(t) - - errChan := make(chan error) - go func() { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - conn, err := serverTransport.SecureInbound(ctx, &delayedConn{Conn: serverInsecureConn, delay: 5 * time.Millisecond}, "") - // crypto/tls' context handling works by spinning up a separate Go routine that watches the context, - // and closes the underlying connection when that context is canceled. - // It is therefore not guaranteed (but very likely) that this happens _during_ the TLS handshake. - if err == nil { - _, err = conn.Read([]byte{0}) - } - errChan <- err - }() - _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) - require.Error(t, err) - require.ErrorIs(t, <-errChan, context.Canceled) - }) - } -} - func TestPeerIDMismatch(t *testing.T) { _, clientKey := createPeer(t) serverID, serverKey := createPeer(t) From 17500fd88afaa3e19a1a8da6044b6149d8750033 Mon Sep 17 00:00:00 2001 From: julian88110 <111450570+julian88110@users.noreply.github.com> Date: Fri, 30 Sep 2022 09:47:34 -0700 Subject: [PATCH 12/13] Address some go nit points --- p2p/security/tls/transport.go | 9 +++++---- p2p/security/tls/transport_test.go | 9 ++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/p2p/security/tls/transport.go b/p2p/security/tls/transport.go index b5549c78db..695f648465 100644 --- a/p2p/security/tls/transport.go +++ b/p2p/security/tls/transport.go @@ -130,10 +130,11 @@ func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ci.PubKey) (sec.Se } nextProto := tlsConn.ConnectionState().NegotiatedProtocol - // The special nextProto ID "libp2p" was used in previous versions upto - // v0.23.2 as the TLS ALPN extension field. If we see this special ID - // selected, that means we are handshaking with an old version of libp2p. - // In this case return empty nextProto to indicate no ALPN is selected. + // The special ALPN extension value "libp2p" is used by libp2p versions + // that don't support early muxer negotiation. If we see this sepcial + // value selected, that means we are handshaking with a version that does + // not support early muxer negotiation. In this case return empty nextProto + // to indicate no muxer is selected. if nextProto == "libp2p" { nextProto = "" } diff --git a/p2p/security/tls/transport_test.go b/p2p/security/tls/transport_test.go index a7d4eacd78..d50784bd83 100644 --- a/p2p/security/tls/transport_test.go +++ b/p2p/security/tls/transport_test.go @@ -194,7 +194,6 @@ func TestHandshakeWithNextProtoSucceeds(t *testing.T) { {[]protocol.ID{}, []protocol.ID{"muxer1"}, ""}, {[]protocol.ID{"muxer2"}, []protocol.ID{"muxer1"}, ""}, } - numMuxers := len(tests) var expectedMuxer string clientID, clientKey := createPeer(t) @@ -241,11 +240,11 @@ func TestHandshakeWithNextProtoSucceeds(t *testing.T) { } // Iterate through the NextProto combinations. - for i := 0; i < numMuxers; i++ { - expectedMuxer = tests[i].expectedResult - clientTransport, err := New(clientKey, tests[i].clientProtos) + for _, test := range tests { + expectedMuxer = test.expectedResult + clientTransport, err := New(clientKey, test.clientProtos) require.NoError(t, err) - serverTransport, err := New(serverKey, tests[i].serverProtos) + serverTransport, err := New(serverKey, test.serverProtos) require.NoError(t, err) t.Run("TLS handshake with ALPN extension", func(t *testing.T) { From a704411453604d3b3edfb2a7250fc2952004953a Mon Sep 17 00:00:00 2001 From: julian88110 <111450570+julian88110@users.noreply.github.com> Date: Fri, 7 Oct 2022 09:40:54 -0700 Subject: [PATCH 13/13] Update tls transport test to address review points --- p2p/security/tls/transport_test.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/p2p/security/tls/transport_test.go b/p2p/security/tls/transport_test.go index d50784bd83..59fa8bdae8 100644 --- a/p2p/security/tls/transport_test.go +++ b/p2p/security/tls/transport_test.go @@ -194,12 +194,11 @@ func TestHandshakeWithNextProtoSucceeds(t *testing.T) { {[]protocol.ID{}, []protocol.ID{"muxer1"}, ""}, {[]protocol.ID{"muxer2"}, []protocol.ID{"muxer1"}, ""}, } - var expectedMuxer string clientID, clientKey := createPeer(t) serverID, serverKey := createPeer(t) - handshake := func(t *testing.T, clientTransport *Transport, serverTransport *Transport) { + handshake := func(t *testing.T, clientTransport *Transport, serverTransport *Transport, expectedMuxer string) { clientInsecureConn, serverInsecureConn := connect(t) serverConnChan := make(chan sec.SecureConn) @@ -241,14 +240,13 @@ func TestHandshakeWithNextProtoSucceeds(t *testing.T) { // Iterate through the NextProto combinations. for _, test := range tests { - expectedMuxer = test.expectedResult clientTransport, err := New(clientKey, test.clientProtos) require.NoError(t, err) serverTransport, err := New(serverKey, test.serverProtos) require.NoError(t, err) t.Run("TLS handshake with ALPN extension", func(t *testing.T) { - handshake(t, clientTransport, serverTransport) + handshake(t, clientTransport, serverTransport, test.expectedResult) }) } }