Skip to content

Commit

Permalink
use a mock clock in cert manager tests (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann authored Jul 10, 2022
1 parent d626e80 commit ebcb513
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 20 deletions.
17 changes: 10 additions & 7 deletions p2p/transport/webtransport/cert_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"sync"
"time"

"github.com/benbjohnson/clock"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multibase"
"github.com/multiformats/go-multihash"
Expand Down Expand Up @@ -40,6 +41,7 @@ func newCertConfig(start, end time.Time, conf *tls.Config) (*certConfig, error)
// We continue to remember the hash of (1) for validation during the Noise handshake for another 4 days,
// as the client might be connecting with a cached address.
type certManager struct {
clock clock.Clock
ctx context.Context
ctxCancel context.CancelFunc
refCount sync.WaitGroup
Expand All @@ -53,26 +55,30 @@ type certManager struct {
addrComp ma.Multiaddr
}

func newCertManager(certValidity time.Duration) (*certManager, error) {
func newCertManager(clock clock.Clock, certValidity time.Duration) (*certManager, error) {
m := &certManager{
clock: clock,
certValidity: certValidity,
}
m.ctx, m.ctxCancel = context.WithCancel(context.Background())
if err := m.init(); err != nil {
return nil, err
}

t := m.clock.Ticker(m.certValidity * 4 / 9) // make sure we're a bit faster than 1/2
m.refCount.Add(1)
go func() {
defer m.refCount.Done()
if err := m.background(); err != nil {
defer t.Stop()
if err := m.background(t); err != nil {
log.Fatal(err)
}
}()
return m, nil
}

func (m *certManager) init() error {
start := time.Now()
start := m.clock.Now()
end := start.Add(m.certValidity)
tlsConf, err := getTLSConf(start, end)
if err != nil {
Expand All @@ -86,10 +92,7 @@ func (m *certManager) init() error {
return m.cacheAddrComponent()
}

func (m *certManager) background() error {
t := time.NewTicker(m.certValidity * 4 / 9) // make sure we're a bit faster than 1/2
defer t.Stop()

func (m *certManager) background(t *clock.Ticker) error {
for {
select {
case <-m.ctx.Done():
Expand Down
25 changes: 14 additions & 11 deletions p2p/transport/webtransport/cert_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package libp2pwebtransport
import (
"crypto/sha256"
"crypto/tls"
"os"
"fmt"
"testing"
"time"

"github.com/benbjohnson/clock"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multibase"
"github.com/multiformats/go-multihash"
Expand Down Expand Up @@ -37,15 +38,17 @@ func certHashFromComponent(t *testing.T, comp ma.Component) []byte {
}

func TestInitialCert(t *testing.T) {
m, err := newCertManager(certValidity)
cl := clock.NewMock()
cl.Add(1234567 * time.Hour)
m, err := newCertManager(cl, certValidity)
require.NoError(t, err)
defer m.Close()

conf := m.GetConfig()
require.Len(t, conf.Certificates, 1)
cert := conf.Certificates[0]
require.WithinDuration(t, time.Now(), cert.Leaf.NotBefore, time.Second)
require.WithinDuration(t, time.Now().Add(certValidity), cert.Leaf.NotAfter, time.Second)
require.Equal(t, cl.Now().UTC(), cert.Leaf.NotBefore)
require.Equal(t, cl.Now().Add(certValidity).UTC(), cert.Leaf.NotAfter)
addr := m.AddrComponent()
components := splitMultiaddr(addr)
require.Len(t, components, 1)
Expand All @@ -55,26 +58,26 @@ func TestInitialCert(t *testing.T) {
}

func TestCertRenewal(t *testing.T) {
var certValidity = 300 * time.Millisecond
if os.Getenv("CI") != "" {
certValidity = 2 * time.Second
}
m, err := newCertManager(certValidity)
cl := clock.NewMock()
m, err := newCertManager(cl, certValidity)
require.NoError(t, err)
defer m.Close()

firstConf := m.GetConfig()
require.Len(t, splitMultiaddr(m.AddrComponent()), 1)
// wait for a new certificate to be generated
require.Eventually(t, func() bool { return len(splitMultiaddr(m.AddrComponent())) > 1 }, certValidity/2, 10*time.Millisecond)
fmt.Println("add time")
cl.Add(certValidity / 2)
require.Eventually(t, func() bool { return len(splitMultiaddr(m.AddrComponent())) > 1 }, 200*time.Millisecond, 10*time.Millisecond)
// the actual config used should still be the same, we're just advertising the hash of the next config
components := splitMultiaddr(m.AddrComponent())
require.Len(t, components, 2)
for _, c := range components {
require.Equal(t, ma.P_CERTHASH, c.Protocol().Code)
}
require.Equal(t, firstConf, m.GetConfig())
require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, certValidity/2, 10*time.Millisecond)
cl.Add(certValidity / 2)
require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, 200*time.Millisecond, 10*time.Millisecond)
newConf := m.GetConfig()
// check that the new config now matches the second component
hash := certificateHashFromTLSConfig(newConf)
Expand Down
21 changes: 19 additions & 2 deletions p2p/transport/webtransport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

noise "github.com/libp2p/go-libp2p-noise"

"github.com/benbjohnson/clock"
logging "github.com/ipfs/go-log/v2"
"github.com/lucas-clemente/quic-go/http3"
"github.com/marten-seemann/webtransport-go"
Expand All @@ -33,9 +34,19 @@ const webtransportHTTPEndpoint = "/.well-known/libp2p-webtransport"

const certValidity = 14 * 24 * time.Hour

type Option func(*transport) error

func WithClock(cl clock.Clock) Option {
return func(t *transport) error {
t.clock = cl
return nil
}
}

type transport struct {
privKey ic.PrivKey
pid peer.ID
clock clock.Clock

rcmgr network.ResourceManager
gater connmgr.ConnectionGater
Expand All @@ -50,7 +61,7 @@ type transport struct {
var _ tpt.Transport = &transport{}
var _ io.Closer = &transport{}

func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Transport, error) {
func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (tpt.Transport, error) {
id, err := peer.IDFromPrivateKey(key)
if err != nil {
return nil, err
Expand All @@ -60,6 +71,12 @@ func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceMa
privKey: key,
rcmgr: rcmgr,
gater: gater,
clock: clock.New(),
}
for _, opt := range opts {
if err := opt(t); err != nil {
return nil, err
}
}
noise, err := noise.New(key, noise.WithEarlyDataHandler(t.checkEarlyData))
if err != nil {
Expand Down Expand Up @@ -208,7 +225,7 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) {
return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr)
}
t.listenOnce.Do(func() {
t.certManager, t.listenOnceErr = newCertManager(certValidity)
t.certManager, t.listenOnceErr = newCertManager(t.clock, certValidity)
})
if t.listenOnceErr != nil {
return nil, t.listenOnceErr
Expand Down

0 comments on commit ebcb513

Please sign in to comment.