Skip to content

Commit

Permalink
webtransport: have the server send the certificates
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Sep 19, 2022
1 parent c1bdab4 commit 6795012
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 54 deletions.
42 changes: 32 additions & 10 deletions p2p/transport/webtransport/cert_manager.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package libp2pwebtransport

import (
"bytes"
"context"
"crypto/sha256"
"crypto/tls"
"fmt"
"sync"
"time"

pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb"

"github.com/benbjohnson/clock"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multihash"
Expand Down Expand Up @@ -54,6 +55,8 @@ type certManager struct {
currentConfig *certConfig
nextConfig *certConfig // nil until we have passed half the certValidity of the current config
addrComp ma.Multiaddr

protobuf []byte
}

func newCertManager(clock clock.Clock) (*certManager, error) {
Expand Down Expand Up @@ -88,6 +91,9 @@ func (m *certManager) rollConfig() error {
m.lastConfig = m.currentConfig
m.currentConfig = m.nextConfig
m.nextConfig = c
if err := m.cacheProtobuf(); err != nil {
return err
}
return m.cacheAddrComponent()
}

Expand Down Expand Up @@ -131,17 +137,33 @@ func (m *certManager) AddrComponent() ma.Multiaddr {
return m.addrComp
}

func (m *certManager) Verify(hashes []multihash.DecodedMultihash) error {
for _, h := range hashes {
if h.Code != multihash.SHA2_256 {
return fmt.Errorf("expected SHA256 hash, got %d", h.Code)
}
if !bytes.Equal(h.Digest, m.currentConfig.sha256[:]) &&
(m.nextConfig == nil || !bytes.Equal(h.Digest, m.nextConfig.sha256[:])) &&
(m.lastConfig == nil || !bytes.Equal(h.Digest, m.lastConfig.sha256[:])) {
return fmt.Errorf("found unexpected hash: %+x", h.Digest)
func (m *certManager) Protobuf() []byte {
return m.protobuf
}

func (m *certManager) cacheProtobuf() error {
hashes := make([][32]byte, 0, 3)
if m.lastConfig != nil {
hashes = append(hashes, m.lastConfig.sha256)
}
hashes = append(hashes, m.currentConfig.sha256)
if m.nextConfig != nil {
hashes = append(hashes, m.nextConfig.sha256)
}

msg := pb.WebTransport{CertHashes: make([][]byte, 0, len(hashes))}
for _, certHash := range hashes {
h, err := multihash.Encode(certHash[:], multihash.SHA2_256)
if err != nil {
return fmt.Errorf("failed to encode certificate hash: %w", err)
}
msg.CertHashes = append(msg.CertHashes, h)
}
msgBytes, err := msg.Marshal()
if err != nil {
return fmt.Errorf("failed to marshal WebTransport protobuf: %w", err)
}
m.protobuf = msgBytes
return nil
}

Expand Down
43 changes: 15 additions & 28 deletions p2p/transport/webtransport/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@ import (
"net/http"
"time"

pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb"

"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/network"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/security/noise"
pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb"

"github.com/lucas-clemente/quic-go/http3"
"github.com/marten-seemann/webtransport-go"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multihash"
)

var errClosed = errors.New("closed")
Expand Down Expand Up @@ -197,7 +197,19 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*
if err != nil {
return nil, err
}
n, err := l.noise.WithSessionOptions(noise.EarlyData(nil, newEarlyDataReceiver(l.checkEarlyData)))
var earlyData []byte
if l.isStaticTLSConf {
var msg pb.WebTransport
var err error
earlyData, err = msg.Marshal()
if err != nil {
return nil, err
}
} else {
earlyData = l.certManager.Protobuf()
}

n, err := l.noise.WithSessionOptions(noise.EarlyData(nil, newEarlyDataSender(earlyData)))
if err != nil {
return nil, fmt.Errorf("failed to initialize Noise session: %w", err)
}
Expand All @@ -212,31 +224,6 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*
}, nil
}

func (l *listener) checkEarlyData(b []byte) error {
var msg pb.WebTransport
if err := msg.Unmarshal(b); err != nil {
fmt.Println(1)
return fmt.Errorf("failed to unmarshal early data protobuf: %w", err)
}

if l.isStaticTLSConf {
if len(msg.CertHashes) > 0 {
return errors.New("using static TLS config, didn't expect any certificate hashes")
}
return nil
}

hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes))
for _, h := range msg.CertHashes {
dh, err := multihash.Decode(h)
if err != nil {
return fmt.Errorf("failed to decode hash: %w", err)
}
hashes = append(hashes, *dh)
}
return l.certManager.Verify(hashes)
}

func (l *listener) Addr() net.Addr {
return l.addr
}
Expand Down
54 changes: 43 additions & 11 deletions p2p/transport/webtransport/transport.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package libp2pwebtransport

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"sync"
Expand Down Expand Up @@ -196,32 +198,62 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p

// Now run a Noise handshake (using early data) and send all the certificate hashes that we would have accepted.
// The server will verify that it advertised all of these certificate hashes.
msg := pb.WebTransport{CertHashes: make([][]byte, 0, len(certHashes))}
for _, certHash := range certHashes {
h, err := multihash.Encode(certHash.Digest, certHash.Code)
var verified bool
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(func(b []byte) error {
decodedCertHashes, err := decodeCertHashesFromProtobuf(b)
if err != nil {
return nil, fmt.Errorf("failed to encode certificate hash: %w", err)
return err
}
msg.CertHashes = append(msg.CertHashes, h)
}
msgBytes, err := msg.Marshal()
if err != nil {
return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err)
}
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataSender(msgBytes), nil))
for _, sent := range certHashes {
var found bool
for _, rcvd := range decodedCertHashes {
if sent.Code == rcvd.Code && bytes.Equal(sent.Digest, rcvd.Digest) {
found = true
break
}
}
if !found {
return fmt.Errorf("missing cert hash: %v", sent)
}
}
verified = true
return nil
}), nil))
if err != nil {
return nil, fmt.Errorf("failed to create Noise transport: %w", err)
}
c, err := n.SecureOutbound(ctx, &webtransportStream{Stream: str, wsess: sess}, p)
if err != nil {
return nil, err
}
// The Noise handshake _should_ guarantee that our verification callback is called.
// Double-check just in case.
if !verified {
return nil, errors.New("didn't verify")
}
return &connSecurityMultiaddrs{
ConnSecurity: c,
ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote},
}, nil
}

func decodeCertHashesFromProtobuf(b []byte) ([]multihash.DecodedMultihash, error) {
var msg pb.WebTransport
if err := msg.Unmarshal(b); err != nil {
return nil, fmt.Errorf("failed to unmarshal early data protobuf: %w", err)
}

hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes))
for _, h := range msg.CertHashes {
dh, err := multihash.Decode(h)
if err != nil {
return nil, fmt.Errorf("failed to decode hash: %w", err)
}
hashes = append(hashes, *dh)
}
return hashes, nil
}

func (t *transport) CanDial(addr ma.Multiaddr) bool {
var numHashes int
ma.ForEach(addr, func(c ma.Component) bool {
Expand Down
7 changes: 2 additions & 5 deletions p2p/transport/webtransport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,8 @@ func TestHashVerification(t *testing.T) {
})

t.Run("fails when adding a wrong hash", func(t *testing.T) {
conn, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID)
if err != nil {
_, err = conn.AcceptStream()
require.Error(t, err)
}
_, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID)
require.Error(t, err)
})

require.NoError(t, ln.Close())
Expand Down

0 comments on commit 6795012

Please sign in to comment.