diff --git a/config/config.go b/config/config.go index 7b11c2e357..0a68736f64 100644 --- a/config/config.go +++ b/config/config.go @@ -23,11 +23,11 @@ import ( bhost "github.com/libp2p/go-libp2p/p2p/host/basic" blankhost "github.com/libp2p/go-libp2p/p2p/host/blank" routed "github.com/libp2p/go-libp2p/p2p/host/routed" + "github.com/libp2p/go-libp2p/p2p/net/swarm" circuitv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" - swarm "github.com/libp2p/go-libp2p-swarm" tptu "github.com/libp2p/go-libp2p-transport-upgrader" logging "github.com/ipfs/go-log/v2" diff --git a/config/muxer_test.go b/config/muxer_test.go index 50e32ef07d..76aada1061 100644 --- a/config/muxer_test.go +++ b/config/muxer_test.go @@ -3,12 +3,12 @@ package config import ( "testing" - "github.com/libp2p/go-libp2p-core/network" + bhost "github.com/libp2p/go-libp2p/p2p/host/basic" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" - bhost "github.com/libp2p/go-libp2p/p2p/host/basic" yamux "github.com/libp2p/go-libp2p-yamux" ) diff --git a/go.mod b/go.mod index 655228de67..4dcf718c66 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,6 @@ require ( github.com/libp2p/go-libp2p-peerstore v0.6.0 github.com/libp2p/go-libp2p-quic-transport v0.17.0 github.com/libp2p/go-libp2p-resource-manager v0.2.1 - github.com/libp2p/go-libp2p-swarm v0.10.2 github.com/libp2p/go-libp2p-testing v0.9.2 github.com/libp2p/go-libp2p-tls v0.4.1 github.com/libp2p/go-libp2p-transport-upgrader v0.7.1 @@ -35,12 +34,14 @@ require ( github.com/libp2p/zeroconf/v2 v2.1.1 github.com/multiformats/go-multiaddr v0.5.0 github.com/multiformats/go-multiaddr-dns v0.3.1 + github.com/multiformats/go-multiaddr-fmt v0.1.0 github.com/multiformats/go-multihash v0.1.0 github.com/multiformats/go-multistream v0.3.0 github.com/multiformats/go-varint v0.0.6 github.com/raulk/go-watchdog v1.2.0 github.com/stretchr/testify v1.7.0 github.com/whyrusleeping/mdns v0.0.0-20190826153040-b9b60ed33aa9 + github.com/whyrusleeping/multiaddr-filter v0.0.0-20160516205228-e903e4adabd7 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c ) @@ -76,6 +77,7 @@ require ( github.com/libp2p/go-flow-metrics v0.0.3 // indirect github.com/libp2p/go-libp2p-blankhost v0.3.0 // indirect github.com/libp2p/go-libp2p-pnet v0.2.0 // indirect + github.com/libp2p/go-libp2p-swarm v0.10.2 // indirect github.com/libp2p/go-nat v0.1.0 // indirect github.com/libp2p/go-openssl v0.0.7 // indirect github.com/libp2p/go-reuseport v0.1.0 // indirect @@ -96,7 +98,6 @@ require ( github.com/mr-tron/base58 v1.2.0 // indirect github.com/multiformats/go-base32 v0.0.4 // indirect github.com/multiformats/go-base36 v0.1.0 // indirect - github.com/multiformats/go-multiaddr-fmt v0.1.0 // indirect github.com/multiformats/go-multibase v0.0.3 // indirect github.com/multiformats/go-multicodec v0.4.1 // indirect github.com/nxadm/tail v1.4.8 // indirect @@ -112,7 +113,6 @@ require ( github.com/raulk/clock v1.1.0 // indirect github.com/spacemonkeygo/spacelog v0.0.0-20180420211403-2296661a0572 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect - 
github.com/whyrusleeping/multiaddr-filter v0.0.0-20160516205228-e903e4adabd7 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.8.0 // indirect go.uber.org/zap v1.21.0 // indirect diff --git a/p2p/discovery/backoff/backoffcache_test.go b/p2p/discovery/backoff/backoffcache_test.go index f297ddfbb0..920b45924c 100644 --- a/p2p/discovery/backoff/backoffcache_test.go +++ b/p2p/discovery/backoff/backoffcache_test.go @@ -9,11 +9,10 @@ import ( "github.com/libp2p/go-libp2p/p2p/discovery/mocks" bhost "github.com/libp2p/go-libp2p/p2p/host/blank" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p-core/discovery" "github.com/libp2p/go-libp2p-core/peer" - - swarmt "github.com/libp2p/go-libp2p-swarm/testing" ) type delayedDiscovery struct { diff --git a/p2p/discovery/backoff/backoffconnector_test.go b/p2p/discovery/backoff/backoffconnector_test.go index b421658dac..6bf958f9c9 100644 --- a/p2p/discovery/backoff/backoffconnector_test.go +++ b/p2p/discovery/backoff/backoffconnector_test.go @@ -8,12 +8,11 @@ import ( "time" bhost "github.com/libp2p/go-libp2p/p2p/host/blank" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" - "github.com/stretchr/testify/require" ) diff --git a/p2p/discovery/mdns_legacy/mdns_test.go b/p2p/discovery/mdns_legacy/mdns_test.go index 0e667e149f..17a5fe77a0 100644 --- a/p2p/discovery/mdns_legacy/mdns_test.go +++ b/p2p/discovery/mdns_legacy/mdns_test.go @@ -5,12 +5,13 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" + bhost "github.com/libp2p/go-libp2p/p2p/host/basic" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" - bhost "github.com/libp2p/go-libp2p/p2p/host/basic" + + "github.com/stretchr/testify/require" ) type DiscoveryNotifee struct { @@ -22,7 +23,7 @@ func (n *DiscoveryNotifee) HandlePeerFound(pi peer.AddrInfo) { } func TestMdnsDiscovery(t *testing.T) { - //TODO: re-enable when the new lib will get integrated + // TODO: re-enable when the new lib will get integrated t.Skip("TestMdnsDiscovery fails randomly with current lib") ctx, cancel := context.WithCancel(context.Background()) diff --git a/p2p/discovery/routing/routing_test.go b/p2p/discovery/routing/routing_test.go index 13f4da3adf..3f88297237 100644 --- a/p2p/discovery/routing/routing_test.go +++ b/p2p/discovery/routing/routing_test.go @@ -9,12 +9,12 @@ import ( "github.com/libp2p/go-libp2p/p2p/discovery/mocks" "github.com/libp2p/go-libp2p/p2p/discovery/util" bhost "github.com/libp2p/go-libp2p/p2p/host/blank" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/ipfs/go-cid" "github.com/libp2p/go-libp2p-core/discovery" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" ) type mockRoutingTable struct { diff --git a/p2p/host/autonat/autonat_test.go b/p2p/host/autonat/autonat_test.go index 7cf6a8efd7..7f111bd30b 100644 --- a/p2p/host/autonat/autonat_test.go +++ b/p2p/host/autonat/autonat_test.go @@ -7,13 +7,13 @@ import ( pb "github.com/libp2p/go-libp2p/p2p/host/autonat/pb" bhost "github.com/libp2p/go-libp2p/p2p/host/blank" + swarmt 
"github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p-core/event" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" "github.com/libp2p/go-msgio/protoio" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" diff --git a/p2p/host/autonat/dialpolicy_test.go b/p2p/host/autonat/dialpolicy_test.go index 8ee70dbb6a..afa047eb73 100644 --- a/p2p/host/autonat/dialpolicy_test.go +++ b/p2p/host/autonat/dialpolicy_test.go @@ -6,10 +6,12 @@ import ( "net" "testing" + blankhost "github.com/libp2p/go-libp2p/p2p/host/blank" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/transport" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" - blankhost "github.com/libp2p/go-libp2p/p2p/host/blank" + "github.com/multiformats/go-multiaddr" ) diff --git a/p2p/host/autonat/svc_test.go b/p2p/host/autonat/svc_test.go index 56d8dd7d92..f8768689d9 100644 --- a/p2p/host/autonat/svc_test.go +++ b/p2p/host/autonat/svc_test.go @@ -7,13 +7,12 @@ import ( "time" bhost "github.com/libp2p/go-libp2p/p2p/host/blank" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p-core/event" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/network" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" - ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" ) diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 9b229b4c03..bfda432f55 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -11,7 +11,10 @@ import ( "testing" "time" - "github.com/libp2p/go-eventbus" + "github.com/libp2p/go-libp2p/p2p/host/autonat" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + "github.com/libp2p/go-libp2p/p2p/protocol/identify" + "github.com/libp2p/go-libp2p-core/event" "github.com/libp2p/go-libp2p-core/helpers" "github.com/libp2p/go-libp2p-core/host" @@ -21,10 +24,8 @@ import ( "github.com/libp2p/go-libp2p-core/protocol" "github.com/libp2p/go-libp2p-core/record" "github.com/libp2p/go-libp2p-core/test" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" - "github.com/libp2p/go-libp2p/p2p/host/autonat" - "github.com/libp2p/go-libp2p/p2p/protocol/identify" + "github.com/libp2p/go-eventbus" ma "github.com/multiformats/go-multiaddr" madns "github.com/multiformats/go-multiaddr-dns" diff --git a/p2p/host/basic/peer_connectedness_test.go b/p2p/host/basic/peer_connectedness_test.go index 818499de19..fc8ecb60a8 100644 --- a/p2p/host/basic/peer_connectedness_test.go +++ b/p2p/host/basic/peer_connectedness_test.go @@ -5,10 +5,11 @@ import ( "testing" "time" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + "github.com/libp2p/go-libp2p-core/event" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" "github.com/stretchr/testify/require" ) diff --git a/p2p/host/blank/peer_connectedness_test.go b/p2p/host/blank/peer_connectedness_test.go index 88c8f2861a..cabf3fe14a 100644 --- a/p2p/host/blank/peer_connectedness_test.go +++ b/p2p/host/blank/peer_connectedness_test.go @@ -5,10 +5,11 @@ import ( "testing" "time" + 
swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + "github.com/libp2p/go-libp2p-core/event" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" "github.com/stretchr/testify/require" ) diff --git a/p2p/net/swarm/addrs.go b/p2p/net/swarm/addrs.go new file mode 100644 index 0000000000..09f8df766c --- /dev/null +++ b/p2p/net/swarm/addrs.go @@ -0,0 +1,35 @@ +package swarm + +import ( + ma "github.com/multiformats/go-multiaddr" + mamask "github.com/whyrusleeping/multiaddr-filter" +) + +// http://www.iana.org/assignments/iana-ipv4-special-registry/iana-ipv4-special-registry.xhtml +var lowTimeoutFilters = ma.NewFilters() + +func init() { + for _, p := range []string{ + "/ip4/10.0.0.0/ipcidr/8", + "/ip4/100.64.0.0/ipcidr/10", + "/ip4/169.254.0.0/ipcidr/16", + "/ip4/172.16.0.0/ipcidr/12", + "/ip4/192.0.0.0/ipcidr/24", + "/ip4/192.0.0.0/ipcidr/29", + "/ip4/192.0.0.8/ipcidr/32", + "/ip4/192.0.0.170/ipcidr/32", + "/ip4/192.0.0.171/ipcidr/32", + "/ip4/192.0.2.0/ipcidr/24", + "/ip4/192.168.0.0/ipcidr/16", + "/ip4/198.18.0.0/ipcidr/15", + "/ip4/198.51.100.0/ipcidr/24", + "/ip4/203.0.113.0/ipcidr/24", + "/ip4/240.0.0.0/ipcidr/4", + } { + f, err := mamask.NewMask(p) + if err != nil { + panic("error in lowTimeoutFilters init: " + err.Error()) + } + lowTimeoutFilters.AddFilter(*f, ma.ActionDeny) + } +} diff --git a/p2p/net/swarm/dial_error.go b/p2p/net/swarm/dial_error.go new file mode 100644 index 0000000000..f2986348bf --- /dev/null +++ b/p2p/net/swarm/dial_error.go @@ -0,0 +1,71 @@ +package swarm + +import ( + "fmt" + "os" + "strings" + + "github.com/libp2p/go-libp2p-core/peer" + + ma "github.com/multiformats/go-multiaddr" +) + +// maxDialDialErrors is the maximum number of dial errors we record +const maxDialDialErrors = 16 + +// DialError is the error type returned when dialing. +type DialError struct { + Peer peer.ID + DialErrors []TransportError + Cause error + Skipped int +} + +func (e *DialError) Timeout() bool { + return os.IsTimeout(e.Cause) +} + +func (e *DialError) recordErr(addr ma.Multiaddr, err error) { + if len(e.DialErrors) >= maxDialDialErrors { + e.Skipped++ + return + } + e.DialErrors = append(e.DialErrors, TransportError{ + Address: addr, + Cause: err, + }) +} + +func (e *DialError) Error() string { + var builder strings.Builder + fmt.Fprintf(&builder, "failed to dial %s:", e.Peer) + if e.Cause != nil { + fmt.Fprintf(&builder, " %s", e.Cause) + } + for _, te := range e.DialErrors { + fmt.Fprintf(&builder, "\n * [%s] %s", te.Address, te.Cause) + } + if e.Skipped > 0 { + fmt.Fprintf(&builder, "\n ... skipping %d errors ...", e.Skipped) + } + return builder.String() +} + +// Unwrap implements https://godoc.org/golang.org/x/xerrors#Wrapper. +func (e *DialError) Unwrap() error { + return e.Cause +} + +var _ error = (*DialError)(nil) + +// TransportError is the error returned when dialing a specific address. 
+type TransportError struct { + Address ma.Multiaddr + Cause error +} + +func (e *TransportError) Error() string { + return fmt.Sprintf("failed to dial %s: %s", e.Address, e.Cause) +} + +var _ error = (*TransportError)(nil) diff --git a/p2p/net/swarm/dial_sync.go b/p2p/net/swarm/dial_sync.go new file mode 100644 index 0000000000..f24ecd3c41 --- /dev/null +++ b/p2p/net/swarm/dial_sync.go @@ -0,0 +1,109 @@ +package swarm + +import ( + "context" + "sync" + + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" +) + +// dialWorkerFunc is used by dialSync to spawn a new dial worker +type dialWorkerFunc func(peer.ID, <-chan dialRequest) + +// newDialSync constructs a new dialSync +func newDialSync(worker dialWorkerFunc) *dialSync { + return &dialSync{ + dials: make(map[peer.ID]*activeDial), + dialWorker: worker, + } +} + +// dialSync is a dial synchronization helper that ensures that at most one dial +// to any given peer is active at any given time. +type dialSync struct { + mutex sync.Mutex + dials map[peer.ID]*activeDial + dialWorker dialWorkerFunc +} + +type activeDial struct { + refCnt int + + ctx context.Context + cancel func() + + reqch chan dialRequest +} + +func (ad *activeDial) close() { + ad.cancel() + close(ad.reqch) +} + +func (ad *activeDial) dial(ctx context.Context) (*Conn, error) { + dialCtx := ad.ctx + + if forceDirect, reason := network.GetForceDirectDial(ctx); forceDirect { + dialCtx = network.WithForceDirectDial(dialCtx, reason) + } + if simConnect, isClient, reason := network.GetSimultaneousConnect(ctx); simConnect { + dialCtx = network.WithSimultaneousConnect(dialCtx, isClient, reason) + } + + resch := make(chan dialResponse, 1) + select { + case ad.reqch <- dialRequest{ctx: dialCtx, resch: resch}: + case <-ctx.Done(): + return nil, ctx.Err() + } + + select { + case res := <-resch: + return res.conn, res.err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (ds *dialSync) getActiveDial(p peer.ID) (*activeDial, error) { + ds.mutex.Lock() + defer ds.mutex.Unlock() + + actd, ok := ds.dials[p] + if !ok { + // This code intentionally uses the background context. Otherwise, if the first call + // to Dial is canceled, subsequent dial calls will also be canceled. + ctx, cancel := context.WithCancel(context.Background()) + actd = &activeDial{ + ctx: ctx, + cancel: cancel, + reqch: make(chan dialRequest), + } + go ds.dialWorker(p, actd.reqch) + ds.dials[p] = actd + } + // increase ref count before dropping mutex + actd.refCnt++ + return actd, nil +} + +// Dial initiates a dial to the given peer if there are none in progress +// then waits for the dial to that peer to complete. 
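+// Concurrent callers for the same peer share a single dial worker and each receives the resulting connection or error.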
+func (ds *dialSync) Dial(ctx context.Context, p peer.ID) (*Conn, error) { + ad, err := ds.getActiveDial(p) + if err != nil { + return nil, err + } + + defer func() { + ds.mutex.Lock() + defer ds.mutex.Unlock() + ad.refCnt-- + if ad.refCnt == 0 { + ad.close() + delete(ds.dials, p) + } + }() + return ad.dial(ctx) +} diff --git a/p2p/net/swarm/dial_sync_test.go b/p2p/net/swarm/dial_sync_test.go new file mode 100644 index 0000000000..0d9c6ca413 --- /dev/null +++ b/p2p/net/swarm/dial_sync_test.go @@ -0,0 +1,231 @@ +package swarm + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/libp2p/go-libp2p-core/peer" +) + +func getMockDialFunc() (dialWorkerFunc, func(), context.Context, <-chan struct{}) { + dfcalls := make(chan struct{}, 512) // buffer it large enough that we won't care + dialctx, cancel := context.WithCancel(context.Background()) + ch := make(chan struct{}) + f := func(p peer.ID, reqch <-chan dialRequest) { + defer cancel() + dfcalls <- struct{}{} + go func() { + for req := range reqch { + <-ch + req.resch <- dialResponse{conn: new(Conn)} + } + }() + } + + var once sync.Once + return f, func() { once.Do(func() { close(ch) }) }, dialctx, dfcalls +} + +func TestBasicDialSync(t *testing.T) { + df, done, _, callsch := getMockDialFunc() + dsync := newDialSync(df) + p := peer.ID("testpeer") + + finished := make(chan struct{}, 2) + go func() { + if _, err := dsync.Dial(context.Background(), p); err != nil { + t.Error(err) + } + finished <- struct{}{} + }() + + go func() { + if _, err := dsync.Dial(context.Background(), p); err != nil { + t.Error(err) + } + finished <- struct{}{} + }() + + // short sleep just to make sure we've moved around in the scheduler + time.Sleep(time.Millisecond * 20) + done() + + <-finished + <-finished + + if len(callsch) > 1 { + t.Fatal("should only have called dial func once!") + } +} + +func TestDialSyncCancel(t *testing.T) { + df, done, _, dcall := getMockDialFunc() + + dsync := newDialSync(df) + + p := peer.ID("testpeer") + + ctx1, cancel1 := context.WithCancel(context.Background()) + + finished := make(chan struct{}) + go func() { + _, err := dsync.Dial(ctx1, p) + if err != ctx1.Err() { + t.Error("should have gotten context error") + } + finished <- struct{}{} + }() + + // make sure the above makes it through the wait code first + select { + case <-dcall: + case <-time.After(time.Second): + t.Fatal("timed out waiting for dial to start") + } + + // Add a second dialwait in so two actors are waiting on the same dial + go func() { + _, err := dsync.Dial(context.Background(), p) + if err != nil { + t.Error(err) + } + finished <- struct{}{} + }() + + time.Sleep(time.Millisecond * 20) + + // cancel the first dialwait, it should not affect the second at all + cancel1() + select { + case <-finished: + case <-time.After(time.Second): + t.Fatal("timed out waiting for wait to exit") + } + + // short sleep just to make sure we've moved around in the scheduler + time.Sleep(time.Millisecond * 20) + done() + + <-finished +} + +func TestDialSyncAllCancel(t *testing.T) { + df, done, dctx, _ := getMockDialFunc() + + dsync := newDialSync(df) + p := peer.ID("testpeer") + ctx, cancel := context.WithCancel(context.Background()) + + finished := make(chan struct{}) + go func() { + if _, err := dsync.Dial(ctx, p); err != ctx.Err() { + t.Error("should have gotten context error") + } + finished <- struct{}{} + }() + + // Add a second dialwait in so two actors are waiting on the same dial + go func() { + if _, err := dsync.Dial(ctx, p); err != 
ctx.Err() { + t.Error("should have gotten context error") + } + finished <- struct{}{} + }() + + cancel() + for i := 0; i < 2; i++ { + select { + case <-finished: + case <-time.After(time.Second): + t.Fatal("timed out waiting for wait to exit") + } + } + + // the dial should have exited now + select { + case <-dctx.Done(): + case <-time.After(time.Second): + t.Fatal("timed out waiting for dial to return") + } + + // should be able to successfully dial that peer again + done() + if _, err := dsync.Dial(context.Background(), p); err != nil { + t.Fatal(err) + } +} + +func TestFailFirst(t *testing.T) { + var count int32 + f := func(p peer.ID, reqch <-chan dialRequest) { + go func() { + for { + req, ok := <-reqch + if !ok { + return + } + + if atomic.LoadInt32(&count) > 0 { + req.resch <- dialResponse{conn: new(Conn)} + } else { + req.resch <- dialResponse{err: fmt.Errorf("gophers ate the modem")} + } + atomic.AddInt32(&count, 1) + } + }() + } + + ds := newDialSync(f) + p := peer.ID("testing") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + if _, err := ds.Dial(ctx, p); err == nil { + t.Fatal("expected gophers to have eaten the modem") + } + + c, err := ds.Dial(ctx, p) + if err != nil { + t.Fatal(err) + } + if c == nil { + t.Fatal("should have gotten a 'real' conn back") + } +} + +func TestStressActiveDial(t *testing.T) { + ds := newDialSync(func(p peer.ID, reqch <-chan dialRequest) { + go func() { + for { + req, ok := <-reqch + if !ok { + return + } + req.resch <- dialResponse{} + } + }() + }) + + wg := sync.WaitGroup{} + + pid := peer.ID("foo") + + makeDials := func() { + for i := 0; i < 10000; i++ { + ds.Dial(context.Background(), pid) + } + wg.Done() + } + + for i := 0; i < 100; i++ { + wg.Add(1) + go makeDials() + } + + wg.Wait() +} diff --git a/p2p/net/swarm/dial_test.go b/p2p/net/swarm/dial_test.go new file mode 100644 index 0000000000..0a98d69e43 --- /dev/null +++ b/p2p/net/swarm/dial_test.go @@ -0,0 +1,616 @@ +package swarm_test + +import ( + "context" + "net" + "sync" + "testing" + "time" + + . 
"github.com/libp2p/go-libp2p/p2p/net/swarm" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/peerstore" + testutil "github.com/libp2p/go-libp2p-core/test" + + "github.com/libp2p/go-libp2p-testing/ci" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + + "github.com/stretchr/testify/require" +) + +func closeSwarms(swarms []*Swarm) { + for _, s := range swarms { + s.Close() + } +} + +func TestBasicDialPeer(t *testing.T) { + t.Parallel() + + swarms := makeSwarms(t, 2) + defer closeSwarms(swarms) + s1 := swarms[0] + s2 := swarms[1] + + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL) + + c, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + + s, err := c.NewStream(context.Background()) + require.NoError(t, err) + s.Close() +} + +func TestDialWithNoListeners(t *testing.T) { + t.Parallel() + + s1 := makeDialOnlySwarm(t) + swarms := makeSwarms(t, 1) + defer closeSwarms(swarms) + s2 := swarms[0] + + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL) + + c, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + + s, err := c.NewStream(context.Background()) + require.NoError(t, err) + s.Close() +} + +func acceptAndHang(l net.Listener) { + conns := make([]net.Conn, 0, 10) + for { + c, err := l.Accept() + if err != nil { + break + } + if c != nil { + conns = append(conns, c) + } + } + for _, c := range conns { + c.Close() + } +} + +func TestSimultDials(t *testing.T) { + t.Parallel() + + ctx := context.Background() + swarms := makeSwarms(t, 2, swarmt.OptDisableReuseport) + defer closeSwarms(swarms) + + // connect everyone + { + var wg sync.WaitGroup + errs := make(chan error, 20) // 2 connect calls in each of the 10 for-loop iterations + connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { + // copy for other peer + log.Debugf("TestSimultOpen: connecting: %s --> %s (%s)", s.LocalPeer(), dst, addr) + s.Peerstore().AddAddr(dst, addr, peerstore.TempAddrTTL) + if _, err := s.DialPeer(ctx, dst); err != nil { + errs <- err + } + wg.Done() + } + + ifaceAddrs0, err := swarms[0].InterfaceListenAddresses() + if err != nil { + t.Fatal(err) + } + ifaceAddrs1, err := swarms[1].InterfaceListenAddresses() + if err != nil { + t.Fatal(err) + } + + log.Info("Connecting swarms simultaneously.") + for i := 0; i < 10; i++ { // connect 10x for each. 
+ wg.Add(2) + go connect(swarms[0], swarms[1].LocalPeer(), ifaceAddrs1[0]) + go connect(swarms[1], swarms[0].LocalPeer(), ifaceAddrs0[0]) + } + wg.Wait() + close(errs) + + for err := range errs { + if err != nil { + t.Fatal("error swarm dialing to peer", err) + } + } + } + + // should still just have 1, at most 2 connections :) + c01l := len(swarms[0].ConnsToPeer(swarms[1].LocalPeer())) + if c01l > 2 { + t.Error("0->1 has", c01l) + } + c10l := len(swarms[1].ConnsToPeer(swarms[0].LocalPeer())) + if c10l > 2 { + t.Error("1->0 has", c10l) + } +} + +func newSilentPeer(t *testing.T) (peer.ID, ma.Multiaddr, net.Listener) { + dst := testutil.RandPeerIDFatal(t) + lst, err := net.Listen("tcp4", "localhost:0") + if err != nil { + t.Fatal(err) + } + addr, err := manet.FromNetAddr(lst.Addr()) + if err != nil { + t.Fatal(err) + } + addrs, err := manet.ResolveUnspecifiedAddresses([]ma.Multiaddr{addr}, nil) + if err != nil { + t.Fatal(err) + } + t.Log("new silent peer:", dst, addrs[0]) + return dst, addrs[0], lst +} + +func TestDialWait(t *testing.T) { + const dialTimeout = 250 * time.Millisecond + + swarms := makeSwarms(t, 1, swarmt.DialTimeout(dialTimeout)) + s1 := swarms[0] + defer s1.Close() + + // dial to a non-existent peer. + s2p, s2addr, s2l := newSilentPeer(t) + go acceptAndHang(s2l) + defer s2l.Close() + s1.Peerstore().AddAddr(s2p, s2addr, peerstore.PermanentAddrTTL) + + before := time.Now() + if c, err := s1.DialPeer(context.Background(), s2p); err == nil { + defer c.Close() + t.Fatal("error swarm dialing to unknown peer worked...", err) + } else { + t.Log("correctly got error:", err) + } + duration := time.Since(before) + + if duration < dialTimeout*DialAttempts { + t.Error("< dialTimeout * DialAttempts not being respected", duration, dialTimeout*DialAttempts) + } + if duration > 2*dialTimeout*DialAttempts { + t.Error("> 2*dialTimeout * DialAttempts not being respected", duration, 2*dialTimeout*DialAttempts) + } + + if !s1.Backoff().Backoff(s2p, s2addr) { + t.Error("s2 should now be on backoff") + } +} + +func TestDialBackoff(t *testing.T) { + if ci.IsRunning() { + t.Skip("travis will never have fun with this test") + } + t.Parallel() + + const dialTimeout = 100 * time.Millisecond + + ctx := context.Background() + swarms := makeSwarms(t, 2, swarmt.DialTimeout(dialTimeout)) + defer closeSwarms(swarms) + s1 := swarms[0] + s2 := swarms[1] + + s2addrs, err := s2.InterfaceListenAddresses() + require.NoError(t, err) + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2addrs, peerstore.PermanentAddrTTL) + + // dial to a non-existent peer. + s3p, s3addr, s3l := newSilentPeer(t) + go acceptAndHang(s3l) + defer s3l.Close() + s1.Peerstore().AddAddr(s3p, s3addr, peerstore.PermanentAddrTTL) + + // in this test we will: + // 1) dial 10x to each node. + // 2) all dials should hang + // 3) s1->s2 should succeed. 
+ // 4) s1->s3 should not (and should place s3 on backoff) + // 5) disconnect entirely + // 6) dial 10x to each node again + // 7) s3 dials should all return immediately (except 1) + // 8) s2 dials should all hang, and succeed + // 9) last s3 dial ends, unsuccessful + + dialOnlineNode := func(dst peer.ID, times int) <-chan bool { + ch := make(chan bool) + for i := 0; i < times; i++ { + go func() { + if _, err := s1.DialPeer(ctx, dst); err != nil { + t.Error("error dialing", dst, err) + ch <- false + } else { + ch <- true + } + }() + } + return ch + } + + dialOfflineNode := func(dst peer.ID, times int) <-chan bool { + ch := make(chan bool) + for i := 0; i < times; i++ { + go func() { + if c, err := s1.DialPeer(ctx, dst); err != nil { + ch <- false + } else { + t.Error("succeeded in dialing", dst) + ch <- true + c.Close() + } + }() + } + return ch + } + + { + // 1) dial 10x to each node. + N := 10 + s2done := dialOnlineNode(s2.LocalPeer(), N) + s3done := dialOfflineNode(s3p, N) + + // when all dials should be done by: + dialTimeout1x := time.After(dialTimeout) + dialTimeout10Ax := time.After(dialTimeout * 2 * 10) // DialAttempts * 10) + + // 2) all dials should hang + select { + case <-s2done: + t.Error("s2 should not happen immediately") + case <-s3done: + t.Error("s3 should not happen yet") + case <-time.After(time.Millisecond): + // s2 may finish very quickly, so let's get out. + } + + // 3) s1->s2 should succeed. + for i := 0; i < N; i++ { + select { + case r := <-s2done: + if !r { + t.Error("s2 should not fail") + } + case <-s3done: + t.Error("s3 should not happen yet") + case <-dialTimeout1x: + t.Error("s2 took too long") + } + } + + select { + case <-s2done: + t.Error("s2 should have no more") + case <-s3done: + t.Error("s3 should not happen yet") + case <-dialTimeout1x: // let it pass + } + + // 4) s1->s3 should not (and should place s3 on backoff) + // N-1 should finish before dialTimeout1x * 2 + for i := 0; i < N; i++ { + select { + case <-s2done: + t.Error("s2 should have no more") + case r := <-s3done: + if r { + t.Error("s3 should not succeed") + } + case <-(dialTimeout1x): + if i < (N - 1) { + t.Fatal("s3 took too long") + } + t.Log("dialTimeout1x * 1.3 hit for last peer") + case <-dialTimeout10Ax: + t.Fatal("s3 took too long") + } + } + + // check backoff state + if s1.Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { + t.Error("s2 should not be on backoff") + } + if !s1.Backoff().Backoff(s3p, s3addr) { + t.Error("s3 should be on backoff") + } + + // 5) disconnect entirely + + for _, c := range s1.Conns() { + c.Close() + } + for i := 0; i < 100 && len(s1.Conns()) > 0; i++ { + <-time.After(time.Millisecond) + } + if len(s1.Conns()) > 0 { + t.Fatal("s1 conns must exit") + } + } + + { + // 6) dial 10x to each node again + N := 10 + s2done := dialOnlineNode(s2.LocalPeer(), N) + s3done := dialOfflineNode(s3p, N) + + // when all dials should be done by: + dialTimeout1x := time.After(dialTimeout) + dialTimeout10Ax := time.After(dialTimeout * 2 * 10) // DialAttempts * 10) + + // 7) s3 dials should all return immediately (except 1) + for i := 0; i < N-1; i++ { + select { + case <-s2done: + t.Error("s2 should not succeed yet") + case r := <-s3done: + if r { + t.Error("s3 should not succeed") + } + case <-dialTimeout1x: + t.Fatal("s3 took too long") + } + } + + // 8) s2 dials should all hang, and succeed + for i := 0; i < N; i++ { + select { + case r := <-s2done: + if !r { + t.Error("s2 should succeed") + } + // case <-s3done: + case <-(dialTimeout1x): + t.Fatal("s3 took too long") + } + } 
+ + // 9) the last s3 should return, failed. + select { + case <-s2done: + t.Error("s2 should have no more") + case r := <-s3done: + if r { + t.Error("s3 should not succeed") + } + case <-dialTimeout10Ax: + t.Fatal("s3 took too long") + } + + // check backoff state (the same) + if s1.Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { + t.Error("s2 should not be on backoff") + } + if !s1.Backoff().Backoff(s3p, s3addr) { + t.Error("s3 should be on backoff") + } + } +} + +func TestDialBackoffClears(t *testing.T) { + t.Parallel() + + const dialTimeout = 250 * time.Millisecond + swarms := makeSwarms(t, 2, swarmt.DialTimeout(dialTimeout)) + defer closeSwarms(swarms) + s1 := swarms[0] + s2 := swarms[1] + + // use another address first, that accept and hang on conns + _, s2bad, s2l := newSilentPeer(t) + go acceptAndHang(s2l) + defer s2l.Close() + + // phase 1 -- dial to non-operational addresses + s1.Peerstore().AddAddr(s2.LocalPeer(), s2bad, peerstore.PermanentAddrTTL) + + before := time.Now() + _, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + require.Error(t, err, "dialing to broken addr worked...") + duration := time.Since(before) + + if duration < dialTimeout*DialAttempts { + t.Error("< dialTimeout * DialAttempts not being respected", duration, dialTimeout*DialAttempts) + } + if duration > 2*dialTimeout*DialAttempts { + t.Error("> 2*dialTimeout * DialAttempts not being respected", duration, 2*dialTimeout*DialAttempts) + } + require.True(t, s1.Backoff().Backoff(s2.LocalPeer(), s2bad), "s2 should now be on backoff") + + // phase 2 -- add the working address. dial should succeed. + ifaceAddrs1, err := s2.InterfaceListenAddresses() + require.NoError(t, err) + s1.Peerstore().AddAddrs(s2.LocalPeer(), ifaceAddrs1, peerstore.PermanentAddrTTL) + + // backoffs are per address, not peer + c, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + defer c.Close() + require.False(t, s1.Backoff().Backoff(s2.LocalPeer(), s2bad), "s2 should no longer be on backoff") +} + +func TestDialPeerFailed(t *testing.T) { + t.Parallel() + + swarms := makeSwarms(t, 2, swarmt.DialTimeout(100*time.Millisecond)) + defer closeSwarms(swarms) + testedSwarm, targetSwarm := swarms[0], swarms[1] + + const expectedErrorsCount = 5 + for i := 0; i < expectedErrorsCount; i++ { + _, silentPeerAddress, silentPeerListener := newSilentPeer(t) + go acceptAndHang(silentPeerListener) + defer silentPeerListener.Close() + + testedSwarm.Peerstore().AddAddr(targetSwarm.LocalPeer(), silentPeerAddress, peerstore.PermanentAddrTTL) + } + + _, err := testedSwarm.DialPeer(context.Background(), targetSwarm.LocalPeer()) + require.Error(t, err) + + // dial_test.go:508: correctly get a combined error: failed to dial PEER: all dials failed + // * [/ip4/127.0.0.1/tcp/46485] failed to negotiate security protocol: context deadline exceeded + // * [/ip4/127.0.0.1/tcp/34881] failed to negotiate security protocol: context deadline exceeded + // ... 
+ + dialErr, ok := err.(*DialError) + if !ok { + t.Fatalf("expected *DialError, got %T", err) + } + + if len(dialErr.DialErrors) != expectedErrorsCount { + t.Errorf("expected %d errors, got %d", expectedErrorsCount, len(dialErr.DialErrors)) + } +} + +func TestDialExistingConnection(t *testing.T) { + swarms := makeSwarms(t, 2) + defer closeSwarms(swarms) + s1 := swarms[0] + s2 := swarms[1] + + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL) + + c1, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + + c2, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + + require.Equal(t, c1, c2, "expecting the same connection from both dials") +} + +func newSilentListener(t *testing.T) ([]ma.Multiaddr, net.Listener) { + lst, err := net.Listen("tcp4", "localhost:0") + if err != nil { + t.Fatal(err) + } + addr, err := manet.FromNetAddr(lst.Addr()) + if err != nil { + t.Fatal(err) + } + addrs, err := manet.ResolveUnspecifiedAddresses([]ma.Multiaddr{addr}, nil) + if err != nil { + t.Fatal(err) + } + return addrs, lst + +} + +func TestDialSimultaneousJoin(t *testing.T) { + const dialTimeout = 250 * time.Millisecond + + swarms := makeSwarms(t, 2, swarmt.DialTimeout(dialTimeout)) + defer closeSwarms(swarms) + s1 := swarms[0] + s2 := swarms[1] + + s2silentAddrs, s2silentListener := newSilentListener(t) + go acceptAndHang(s2silentListener) + + connch := make(chan network.Conn, 512) + errs := make(chan error, 2) + + // start a dial to s2 through the silent addr + go func() { + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2silentAddrs, peerstore.PermanentAddrTTL) + + c, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + if err != nil { + errs <- err + connch <- nil + return + } + + t.Logf("first dial succedded; conn: %+v", c) + + connch <- c + errs <- nil + }() + + // wait a bit for the dial to take hold + time.Sleep(100 * time.Millisecond) + + // start a second dial to s2 that uses the real s2 addrs + go func() { + s2addrs, err := s2.InterfaceListenAddresses() + if err != nil { + errs <- err + return + } + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2addrs[:1], peerstore.PermanentAddrTTL) + + c, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + if err != nil { + errs <- err + connch <- nil + return + } + + t.Logf("second dial succedded; conn: %+v", c) + + connch <- c + errs <- nil + }() + + // wait for the second dial to finish + c2 := <-connch + + // start a third dial to s2, this should get the existing connection from the successful dial + go func() { + c, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + if err != nil { + errs <- err + connch <- nil + return + } + + t.Logf("third dial succedded; conn: %+v", c) + + connch <- c + errs <- nil + }() + + c3 := <-connch + + // raise any errors from the previous goroutines + for i := 0; i < 3; i++ { + require.NoError(t, <-errs) + } + + if c2 != c3 { + t.Fatal("expected c2 and c3 to be the same") + } + + // next, the first dial to s2, using the silent addr should timeout; at this point the dial + // will error but the last chance check will see the existing connection and return it + select { + case c1 := <-connch: + if c1 != c2 { + t.Fatal("expected c1 and c2 to be the same") + } + case <-time.After(2 * dialTimeout): + t.Fatal("no connection from first dial") + } +} + +func TestDialSelf(t *testing.T) { + t.Parallel() + + swarms := makeSwarms(t, 2) + defer closeSwarms(swarms) + s1 := swarms[0] + + _, err := s1.DialPeer(context.Background(), 
s1.LocalPeer()) + require.ErrorIs(t, err, ErrDialToSelf, "expected error from self dial") +} diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go new file mode 100644 index 0000000000..8d5d434197 --- /dev/null +++ b/p2p/net/swarm/dial_worker.go @@ -0,0 +1,316 @@ +package swarm + +import ( + "context" + "sync" + + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +// ///////////////////////////////////////////////////////////////////////////////// +// lo and behold, The Dialer +// TODO explain how all this works +// //////////////////////////////////////////////////////////////////////////////// + +type dialRequest struct { + ctx context.Context + resch chan dialResponse +} + +type dialResponse struct { + conn *Conn + err error +} + +type pendRequest struct { + req dialRequest // the original request + err *DialError // dial error accumulator + addrs map[ma.Multiaddr]struct{} // pending addr dials +} + +type addrDial struct { + addr ma.Multiaddr + ctx context.Context + conn *Conn + err error + requests []int + dialed bool +} + +type dialWorker struct { + s *Swarm + peer peer.ID + reqch <-chan dialRequest + reqno int + requests map[int]*pendRequest + pending map[ma.Multiaddr]*addrDial + resch chan dialResult + + connected bool // true when a connection has been successfully established + + nextDial []ma.Multiaddr + + // ready when we have more addresses to dial (nextDial is not empty) + triggerDial <-chan struct{} + + // for testing + wg sync.WaitGroup +} + +func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest) *dialWorker { + return &dialWorker{ + s: s, + peer: p, + reqch: reqch, + requests: make(map[int]*pendRequest), + pending: make(map[ma.Multiaddr]*addrDial), + resch: make(chan dialResult), + } +} + +func (w *dialWorker) loop() { + w.wg.Add(1) + defer w.wg.Done() + defer w.s.limiter.clearAllPeerDials(w.peer) + + // used to signal readiness to dial and completion of the dial + ready := make(chan struct{}) + close(ready) + +loop: + for { + select { + case req, ok := <-w.reqch: + if !ok { + return + } + + c := w.s.bestAcceptableConnToPeer(req.ctx, w.peer) + if c != nil { + req.resch <- dialResponse{conn: c} + continue loop + } + + addrs, err := w.s.addrsForDial(req.ctx, w.peer) + if err != nil { + req.resch <- dialResponse{err: err} + continue loop + } + + // at this point, len(addrs) > 0 or else it would be error from addrsForDial + // ranke them to process in order + addrs = w.rankAddrs(addrs) + + // create the pending request object + pr := &pendRequest{ + req: req, + err: &DialError{Peer: w.peer}, + addrs: make(map[ma.Multiaddr]struct{}), + } + for _, a := range addrs { + pr.addrs[a] = struct{}{} + } + + // check if any of the addrs has been successfully dialed and accumulate + // errors from complete dials while collecting new addrs to dial/join + var todial []ma.Multiaddr + var tojoin []*addrDial + + for _, a := range addrs { + ad, ok := w.pending[a] + if !ok { + todial = append(todial, a) + continue + } + + if ad.conn != nil { + // dial to this addr was successful, complete the request + req.resch <- dialResponse{conn: ad.conn} + continue loop + } + + if ad.err != nil { + // dial to this addr errored, accumulate the error + pr.err.recordErr(a, ad.err) + delete(pr.addrs, a) + continue + } + + // dial is still pending, add to the join list + tojoin = append(tojoin, ad) + } + + if len(todial) == 0 && 
len(tojoin) == 0 { + // all request applicable addrs have been dialed, we must have errored + req.resch <- dialResponse{err: pr.err} + continue loop + } + + // the request has some pending or new dials, track it and schedule new dials + w.reqno++ + w.requests[w.reqno] = pr + + for _, ad := range tojoin { + if !ad.dialed { + if simConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); simConnect { + if simConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !simConnect { + ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason) + } + } + } + ad.requests = append(ad.requests, w.reqno) + } + + if len(todial) > 0 { + for _, a := range todial { + w.pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{w.reqno}} + } + + w.nextDial = append(w.nextDial, todial...) + w.nextDial = w.rankAddrs(w.nextDial) + + // trigger a new dial now to account for the new addrs we added + w.triggerDial = ready + } + + case <-w.triggerDial: + for _, addr := range w.nextDial { + // spawn the dial + ad := w.pending[addr] + err := w.s.dialNextAddr(ad.ctx, w.peer, addr, w.resch) + if err != nil { + w.dispatchError(ad, err) + } + } + + w.nextDial = nil + w.triggerDial = nil + + case res := <-w.resch: + if res.Conn != nil { + w.connected = true + } + + ad := w.pending[res.Addr] + + if res.Conn != nil { + // we got a connection, add it to the swarm + conn, err := w.s.addConn(res.Conn, network.DirOutbound) + if err != nil { + // oops no, we failed to add it to the swarm + res.Conn.Close() + w.dispatchError(ad, err) + continue loop + } + + // dispatch to still pending requests + for _, reqno := range ad.requests { + pr, ok := w.requests[reqno] + if !ok { + // it has already dispatched a connection + continue + } + + pr.req.resch <- dialResponse{conn: conn} + delete(w.requests, reqno) + } + + ad.conn = conn + ad.requests = nil + + continue loop + } + + // it must be an error -- add backoff if applicable and dispatch + if res.Err != context.Canceled && !w.connected { + // we only add backoff if there has not been a successful connection + // for consistency with the old dialer behavior. + w.s.backf.AddBackoff(w.peer, res.Addr) + } + + w.dispatchError(ad, res.Err) + } + } +} + +// dispatches an error to a specific addr dial +func (w *dialWorker) dispatchError(ad *addrDial, err error) { + ad.err = err + for _, reqno := range ad.requests { + pr, ok := w.requests[reqno] + if !ok { + // has already been dispatched + continue + } + + // accumulate the error + pr.err.recordErr(ad.addr, err) + + delete(pr.addrs, ad.addr) + if len(pr.addrs) == 0 { + // all addrs have erred, dispatch dial error + // but first do a last one check in case an acceptable connection has landed from + // a simultaneous dial that started later and added new acceptable addrs + c := w.s.bestAcceptableConnToPeer(pr.req.ctx, w.peer) + if c != nil { + pr.req.resch <- dialResponse{conn: c} + } else { + pr.req.resch <- dialResponse{err: pr.err} + } + delete(w.requests, reqno) + } + } + + ad.requests = nil + + // if it was a backoff, clear the address dial so that it doesn't inhibit new dial requests. + // this is necessary to support active listen scenarios, where a new dial comes in while + // another dial is in progress, and needs to do a direct connection without inhibitions from + // dial backoff. + // it is also necessary to preserve consisent behaviour with the old dialer -- TestDialBackoff + // regresses without this. 
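+	// only the pending entry for this address is dropped here; the backoff record kept by the swarm is left in place.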
+ if err == ErrDialBackoff { + delete(w.pending, ad.addr) + } +} + +// ranks addresses in descending order of preference for dialing, with the following rules: +// NonRelay > Relay +// NonWS > WS +// Private > Public +// UDP > TCP +func (w *dialWorker) rankAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { + addrTier := func(a ma.Multiaddr) (tier int) { + if isRelayAddr(a) { + tier |= 0b1000 + } + if isExpensiveAddr(a) { + tier |= 0b0100 + } + if !manet.IsPrivateAddr(a) { + tier |= 0b0010 + } + if isFdConsumingAddr(a) { + tier |= 0b0001 + } + + return tier + } + + tiers := make([][]ma.Multiaddr, 16) + for _, a := range addrs { + tier := addrTier(a) + tiers[tier] = append(tiers[tier], a) + } + + result := make([]ma.Multiaddr, 0, len(addrs)) + for _, tier := range tiers { + result = append(result, tier...) + } + + return result +} diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go new file mode 100644 index 0000000000..d9aa115e2f --- /dev/null +++ b/p2p/net/swarm/dial_worker_test.go @@ -0,0 +1,327 @@ +package swarm + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + + csms "github.com/libp2p/go-conn-security-multistream" + "github.com/libp2p/go-libp2p-core/peerstore" + "github.com/libp2p/go-libp2p-core/sec/insecure" + "github.com/libp2p/go-libp2p-core/transport" + "github.com/libp2p/go-libp2p-peerstore/pstoremem" + quic "github.com/libp2p/go-libp2p-quic-transport" + tnet "github.com/libp2p/go-libp2p-testing/net" + tptu "github.com/libp2p/go-libp2p-transport-upgrader" + yamux "github.com/libp2p/go-libp2p-yamux" + msmux "github.com/libp2p/go-stream-muxer-multistream" + tcp "github.com/libp2p/go-tcp-transport" + ma "github.com/multiformats/go-multiaddr" +) + +func makeSwarm(t *testing.T) *Swarm { + p := tnet.RandPeerNetParamsOrFatal(t) + + ps, err := pstoremem.NewPeerstore() + require.NoError(t, err) + ps.AddPubKey(p.ID, p.PubKey) + ps.AddPrivKey(p.ID, p.PrivKey) + t.Cleanup(func() { ps.Close() }) + + s, err := NewSwarm(p.ID, ps, WithDialTimeout(time.Second)) + require.NoError(t, err) + + upgrader := makeUpgrader(t, s) + + var tcpOpts []tcp.Option + tcpOpts = append(tcpOpts, tcp.DisableReuseport()) + tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...) 
+ require.NoError(t, err) + if err := s.AddTransport(tcpTransport); err != nil { + t.Fatal(err) + } + if err := s.Listen(p.Addr); err != nil { + t.Fatal(err) + } + + quicTransport, err := quic.NewTransport(p.PrivKey, nil, nil, nil) + if err != nil { + t.Fatal(err) + } + if err := s.AddTransport(quicTransport); err != nil { + t.Fatal(err) + } + if err := s.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic")); err != nil { + t.Fatal(err) + } + + return s +} + +func makeUpgrader(t *testing.T, n *Swarm) transport.Upgrader { + id := n.LocalPeer() + pk := n.Peerstore().PrivKey(id) + secMuxer := new(csms.SSMuxer) + secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(id, pk)) + + stMuxer := msmux.NewBlankTransport() + stMuxer.AddTransport("/yamux/1.0.0", yamux.DefaultTransport) + u, err := tptu.New(secMuxer, stMuxer) + require.NoError(t, err) + return u +} + +func TestDialWorkerLoopBasic(t *testing.T) { + s1 := makeSwarm(t) + s2 := makeSwarm(t) + defer s1.Close() + defer s2.Close() + + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL) + + reqch := make(chan dialRequest) + resch := make(chan dialResponse) + worker := newDialWorker(s1, s2.LocalPeer(), reqch) + go worker.loop() + + var conn *Conn + reqch <- dialRequest{ctx: context.Background(), resch: resch} + select { + case res := <-resch: + require.NoError(t, res.err) + conn = res.conn + case <-time.After(time.Minute): + t.Fatal("dial didn't complete") + } + + s, err := conn.NewStream(context.Background()) + require.NoError(t, err) + s.Close() + + var conn2 *Conn + reqch <- dialRequest{ctx: context.Background(), resch: resch} + select { + case res := <-resch: + require.NoError(t, res.err) + conn2 = res.conn + case <-time.After(time.Minute): + t.Fatal("dial didn't complete") + } + + require.Equal(t, conn, conn2) + + close(reqch) + worker.wg.Wait() +} + +func TestDialWorkerLoopConcurrent(t *testing.T) { + s1 := makeSwarm(t) + s2 := makeSwarm(t) + defer s1.Close() + defer s2.Close() + + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL) + + reqch := make(chan dialRequest) + worker := newDialWorker(s1, s2.LocalPeer(), reqch) + go worker.loop() + + const dials = 100 + var wg sync.WaitGroup + resch := make(chan dialResponse, dials) + for i := 0; i < dials; i++ { + wg.Add(1) + go func() { + defer wg.Done() + reschgo := make(chan dialResponse, 1) + reqch <- dialRequest{ctx: context.Background(), resch: reschgo} + select { + case res := <-reschgo: + resch <- res + case <-time.After(time.Minute): + resch <- dialResponse{err: errors.New("timed out!")} + } + }() + } + wg.Wait() + + for i := 0; i < dials; i++ { + res := <-resch + require.NoError(t, res.err) + } + + t.Log("all concurrent dials done") + + close(reqch) + worker.wg.Wait() +} + +func TestDialWorkerLoopFailure(t *testing.T) { + s1 := makeSwarm(t) + defer s1.Close() + + p2 := tnet.RandPeerNetParamsOrFatal(t) + + s1.Peerstore().AddAddrs(p2.ID, []ma.Multiaddr{ma.StringCast("/ip4/11.0.0.1/tcp/1234"), ma.StringCast("/ip4/11.0.0.1/udp/1234/quic")}, peerstore.PermanentAddrTTL) + + reqch := make(chan dialRequest) + resch := make(chan dialResponse) + worker := newDialWorker(s1, p2.ID, reqch) + go worker.loop() + + reqch <- dialRequest{ctx: context.Background(), resch: resch} + select { + case res := <-resch: + require.Error(t, res.err) + case <-time.After(time.Minute): + t.Fatal("dial didn't complete") + } + + close(reqch) + worker.wg.Wait() +} + +func TestDialWorkerLoopConcurrentFailure(t *testing.T) { + s1 := makeSwarm(t) 
+ defer s1.Close() + + p2 := tnet.RandPeerNetParamsOrFatal(t) + + s1.Peerstore().AddAddrs(p2.ID, []ma.Multiaddr{ma.StringCast("/ip4/11.0.0.1/tcp/1234"), ma.StringCast("/ip4/11.0.0.1/udp/1234/quic")}, peerstore.PermanentAddrTTL) + + reqch := make(chan dialRequest) + worker := newDialWorker(s1, p2.ID, reqch) + go worker.loop() + + const dials = 100 + var errTimeout = errors.New("timed out!") + var wg sync.WaitGroup + resch := make(chan dialResponse, dials) + for i := 0; i < dials; i++ { + wg.Add(1) + go func() { + defer wg.Done() + reschgo := make(chan dialResponse, 1) + reqch <- dialRequest{ctx: context.Background(), resch: reschgo} + + select { + case res := <-reschgo: + resch <- res + case <-time.After(time.Minute): + resch <- dialResponse{err: errTimeout} + } + }() + } + wg.Wait() + + for i := 0; i < dials; i++ { + res := <-resch + require.Error(t, res.err) + if res.err == errTimeout { + t.Fatal("dial response timed out") + } + } + + t.Log("all concurrent dials done") + + close(reqch) + worker.wg.Wait() +} + +func TestDialWorkerLoopConcurrentMix(t *testing.T) { + s1 := makeSwarm(t) + s2 := makeSwarm(t) + defer s1.Close() + defer s2.Close() + + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL) + s1.Peerstore().AddAddrs(s2.LocalPeer(), []ma.Multiaddr{ma.StringCast("/ip4/11.0.0.1/tcp/1234"), ma.StringCast("/ip4/11.0.0.1/udp/1234/quic")}, peerstore.PermanentAddrTTL) + + reqch := make(chan dialRequest) + worker := newDialWorker(s1, s2.LocalPeer(), reqch) + go worker.loop() + + const dials = 100 + var wg sync.WaitGroup + resch := make(chan dialResponse, dials) + for i := 0; i < dials; i++ { + wg.Add(1) + go func() { + defer wg.Done() + reschgo := make(chan dialResponse, 1) + reqch <- dialRequest{ctx: context.Background(), resch: reschgo} + select { + case res := <-reschgo: + resch <- res + case <-time.After(time.Minute): + resch <- dialResponse{err: errors.New("timed out!")} + } + }() + } + wg.Wait() + + for i := 0; i < dials; i++ { + res := <-resch + require.NoError(t, res.err) + } + + t.Log("all concurrent dials done") + + close(reqch) + worker.wg.Wait() +} + +func TestDialWorkerLoopConcurrentFailureStress(t *testing.T) { + s1 := makeSwarm(t) + defer s1.Close() + + p2 := tnet.RandPeerNetParamsOrFatal(t) + + var addrs []ma.Multiaddr + for i := 0; i < 200; i++ { + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/11.0.0.%d/tcp/%d", i%256, 1234+i))) + } + s1.Peerstore().AddAddrs(p2.ID, addrs, peerstore.PermanentAddrTTL) + + reqch := make(chan dialRequest) + worker := newDialWorker(s1, p2.ID, reqch) + go worker.loop() + + const dials = 100 + var errTimeout = errors.New("timed out!") + var wg sync.WaitGroup + resch := make(chan dialResponse, dials) + for i := 0; i < dials; i++ { + wg.Add(1) + go func() { + defer wg.Done() + reschgo := make(chan dialResponse, 1) + reqch <- dialRequest{ctx: context.Background(), resch: reschgo} + select { + case res := <-reschgo: + resch <- res + case <-time.After(5 * time.Minute): + resch <- dialResponse{err: errTimeout} + } + }() + } + wg.Wait() + + for i := 0; i < dials; i++ { + res := <-resch + require.Error(t, res.err) + if res.err == errTimeout { + t.Fatal("dial response timed out") + } + } + + t.Log("all concurrent dials done") + + close(reqch) + worker.wg.Wait() +} diff --git a/p2p/net/swarm/limiter.go b/p2p/net/swarm/limiter.go new file mode 100644 index 0000000000..6b49d8ec0b --- /dev/null +++ b/p2p/net/swarm/limiter.go @@ -0,0 +1,227 @@ +package swarm + +import ( + "context" + "os" + "strconv" + "sync" + "time" 
+ + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/transport" + + ma "github.com/multiformats/go-multiaddr" +) + +type dialResult struct { + Conn transport.CapableConn + Addr ma.Multiaddr + Err error +} + +type dialJob struct { + addr ma.Multiaddr + peer peer.ID + ctx context.Context + resp chan dialResult + timeout time.Duration +} + +func (dj *dialJob) cancelled() bool { + return dj.ctx.Err() != nil +} + +type dialLimiter struct { + lk sync.Mutex + + fdConsuming int + fdLimit int + waitingOnFd []*dialJob + + dialFunc dialfunc + + activePerPeer map[peer.ID]int + perPeerLimit int + waitingOnPeerLimit map[peer.ID][]*dialJob +} + +type dialfunc func(context.Context, peer.ID, ma.Multiaddr) (transport.CapableConn, error) + +func newDialLimiter(df dialfunc) *dialLimiter { + fd := ConcurrentFdDials + if env := os.Getenv("LIBP2P_SWARM_FD_LIMIT"); env != "" { + if n, err := strconv.ParseInt(env, 10, 32); err == nil { + fd = int(n) + } + } + return newDialLimiterWithParams(df, fd, DefaultPerPeerRateLimit) +} + +func newDialLimiterWithParams(df dialfunc, fdLimit, perPeerLimit int) *dialLimiter { + return &dialLimiter{ + fdLimit: fdLimit, + perPeerLimit: perPeerLimit, + waitingOnPeerLimit: make(map[peer.ID][]*dialJob), + activePerPeer: make(map[peer.ID]int), + dialFunc: df, + } +} + +// freeFDToken frees FD token and if there are any schedules another waiting dialJob +// in it's place +func (dl *dialLimiter) freeFDToken() { + log.Debugf("[limiter] freeing FD token; waiting: %d; consuming: %d", len(dl.waitingOnFd), dl.fdConsuming) + dl.fdConsuming-- + + for len(dl.waitingOnFd) > 0 { + next := dl.waitingOnFd[0] + dl.waitingOnFd[0] = nil // clear out memory + dl.waitingOnFd = dl.waitingOnFd[1:] + + if len(dl.waitingOnFd) == 0 { + // clear out memory. + dl.waitingOnFd = nil + } + + // Skip over canceled dials instead of queuing up a goroutine. + if next.cancelled() { + dl.freePeerToken(next) + continue + } + dl.fdConsuming++ + + // we already have activePerPeer token at this point so we can just dial + go dl.executeDial(next) + return + } +} + +func (dl *dialLimiter) freePeerToken(dj *dialJob) { + log.Debugf("[limiter] freeing peer token; peer %s; addr: %s; active for peer: %d; waiting on peer limit: %d", + dj.peer, dj.addr, dl.activePerPeer[dj.peer], len(dl.waitingOnPeerLimit[dj.peer])) + // release tokens in reverse order than we take them + dl.activePerPeer[dj.peer]-- + if dl.activePerPeer[dj.peer] == 0 { + delete(dl.activePerPeer, dj.peer) + } + + waitlist := dl.waitingOnPeerLimit[dj.peer] + for len(waitlist) > 0 { + next := waitlist[0] + waitlist[0] = nil // clear out memory + waitlist = waitlist[1:] + + if len(waitlist) == 0 { + delete(dl.waitingOnPeerLimit, next.peer) + } else { + dl.waitingOnPeerLimit[next.peer] = waitlist + } + + if next.cancelled() { + continue + } + + dl.activePerPeer[next.peer]++ // just kidding, we still want this token + + dl.addCheckFdLimit(next) + return + } +} + +func (dl *dialLimiter) finishedDial(dj *dialJob) { + dl.lk.Lock() + defer dl.lk.Unlock() + if dl.shouldConsumeFd(dj.addr) { + dl.freeFDToken() + } + + dl.freePeerToken(dj) +} + +func (dl *dialLimiter) shouldConsumeFd(addr ma.Multiaddr) bool { + // we don't consume FD's for relay addresses for now as they will be consumed when the Relay Transport + // actually dials the Relay server. That dial call will also pass through this limiter with + // the address of the relay server i.e. non-relay address. 
+ _, err := addr.ValueForProtocol(ma.P_CIRCUIT) + + isRelay := err == nil + + return !isRelay && isFdConsumingAddr(addr) +} + +func (dl *dialLimiter) addCheckFdLimit(dj *dialJob) { + if dl.shouldConsumeFd(dj.addr) { + if dl.fdConsuming >= dl.fdLimit { + log.Debugf("[limiter] blocked dial waiting on FD token; peer: %s; addr: %s; consuming: %d; "+ + "limit: %d; waiting: %d", dj.peer, dj.addr, dl.fdConsuming, dl.fdLimit, len(dl.waitingOnFd)) + dl.waitingOnFd = append(dl.waitingOnFd, dj) + return + } + + log.Debugf("[limiter] taking FD token: peer: %s; addr: %s; prev consuming: %d", + dj.peer, dj.addr, dl.fdConsuming) + // take token + dl.fdConsuming++ + } + + log.Debugf("[limiter] executing dial; peer: %s; addr: %s; FD consuming: %d; waiting: %d", + dj.peer, dj.addr, dl.fdConsuming, len(dl.waitingOnFd)) + go dl.executeDial(dj) +} + +func (dl *dialLimiter) addCheckPeerLimit(dj *dialJob) { + if dl.activePerPeer[dj.peer] >= dl.perPeerLimit { + log.Debugf("[limiter] blocked dial waiting on peer limit; peer: %s; addr: %s; active: %d; "+ + "peer limit: %d; waiting: %d", dj.peer, dj.addr, dl.activePerPeer[dj.peer], dl.perPeerLimit, + len(dl.waitingOnPeerLimit[dj.peer])) + wlist := dl.waitingOnPeerLimit[dj.peer] + dl.waitingOnPeerLimit[dj.peer] = append(wlist, dj) + return + } + dl.activePerPeer[dj.peer]++ + + dl.addCheckFdLimit(dj) +} + +// AddDialJob tries to take the needed tokens for starting the given dial job. +// If it acquires all needed tokens, it immediately starts the dial, otherwise +// it will put it on the waitlist for the requested token. +func (dl *dialLimiter) AddDialJob(dj *dialJob) { + dl.lk.Lock() + defer dl.lk.Unlock() + + log.Debugf("[limiter] adding a dial job through limiter: %v", dj.addr) + dl.addCheckPeerLimit(dj) +} + +func (dl *dialLimiter) clearAllPeerDials(p peer.ID) { + dl.lk.Lock() + defer dl.lk.Unlock() + delete(dl.waitingOnPeerLimit, p) + log.Debugf("[limiter] clearing all peer dials: %v", p) + // NB: the waitingOnFd list doesn't need to be cleaned out here, we will + // remove them as we encounter them because they are 'cancelled' at this + // point +} + +// executeDial calls the dialFunc, and reports the result through the response +// channel when finished. Once the response is sent it also releases all tokens +// it held during the dial. +func (dl *dialLimiter) executeDial(j *dialJob) { + defer dl.finishedDial(j) + if j.cancelled() { + return + } + + dctx, cancel := context.WithTimeout(j.ctx, j.timeout) + defer cancel() + + con, err := dl.dialFunc(dctx, j.peer, j.addr) + select { + case j.resp <- dialResult{Conn: con, Addr: j.addr, Err: err}: + case <-j.ctx.Done(): + if con != nil { + con.Close() + } + } +} diff --git a/p2p/net/swarm/limiter_test.go b/p2p/net/swarm/limiter_test.go new file mode 100644 index 0000000000..47918ce565 --- /dev/null +++ b/p2p/net/swarm/limiter_test.go @@ -0,0 +1,397 @@ +package swarm + +import ( + "context" + "errors" + "fmt" + "math/rand" + "strconv" + "sync" + "testing" + "time" + + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/test" + "github.com/libp2p/go-libp2p-core/transport" + + ma "github.com/multiformats/go-multiaddr" + mafmt "github.com/multiformats/go-multiaddr-fmt" +) + +func addrWithPort(p int) ma.Multiaddr { + return ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", p)) +} + +// in these tests I use addresses with tcp ports over a certain number to +// signify 'good' addresses that will succeed, and addresses below that number +// will fail. 
This lets us more easily test these different scenarios. +func tcpPortOver(a ma.Multiaddr, n int) bool { + port, err := a.ValueForProtocol(ma.P_TCP) + if err != nil { + panic(err) + } + + pnum, err := strconv.Atoi(port) + if err != nil { + panic(err) + } + + return pnum > n +} + +func tryDialAddrs(ctx context.Context, l *dialLimiter, p peer.ID, addrs []ma.Multiaddr, res chan dialResult) { + for _, a := range addrs { + l.AddDialJob(&dialJob{ + ctx: ctx, + peer: p, + addr: a, + resp: res, + }) + } +} + +func hangDialFunc(hang chan struct{}) dialfunc { + return func(ctx context.Context, p peer.ID, a ma.Multiaddr) (transport.CapableConn, error) { + if mafmt.UTP.Matches(a) { + return transport.CapableConn(nil), nil + } + + _, err := a.ValueForProtocol(ma.P_CIRCUIT) + if err == nil { + return transport.CapableConn(nil), nil + } + + if tcpPortOver(a, 10) { + return transport.CapableConn(nil), nil + } + + <-hang + return nil, fmt.Errorf("test bad dial") + } +} + +func TestLimiterBasicDials(t *testing.T) { + hang := make(chan struct{}) + defer close(hang) + + l := newDialLimiterWithParams(hangDialFunc(hang), ConcurrentFdDials, 4) + + bads := []ma.Multiaddr{addrWithPort(1), addrWithPort(2), addrWithPort(3), addrWithPort(4)} + good := addrWithPort(20) + + resch := make(chan dialResult) + pid := peer.ID("testpeer") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tryDialAddrs(ctx, l, pid, bads, resch) + + l.AddDialJob(&dialJob{ + ctx: ctx, + peer: pid, + addr: good, + resp: resch, + }) + + select { + case <-resch: + t.Fatal("no dials should have completed!") + case <-time.After(time.Millisecond * 100): + } + + // complete a single hung dial + hang <- struct{}{} + + select { + case r := <-resch: + if r.Err == nil { + t.Fatal("should have gotten failed dial result") + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for dial completion") + } + + select { + case r := <-resch: + if r.Err != nil { + t.Fatal("expected second result to be success!") + } + case <-time.After(time.Second): + } +} + +func TestFDLimiting(t *testing.T) { + hang := make(chan struct{}) + defer close(hang) + l := newDialLimiterWithParams(hangDialFunc(hang), 16, 5) + + bads := []ma.Multiaddr{addrWithPort(1), addrWithPort(2), addrWithPort(3), addrWithPort(4)} + pids := []peer.ID{"testpeer1", "testpeer2", "testpeer3", "testpeer4"} + goodTCP := addrWithPort(20) + + ctx := context.Background() + resch := make(chan dialResult) + + // take all fd limit tokens with hang dials + for _, pid := range pids { + tryDialAddrs(ctx, l, pid, bads, resch) + } + + // these dials should work normally, but will hang because we have taken + // up all the fd limiting + for _, pid := range pids { + l.AddDialJob(&dialJob{ + ctx: ctx, + peer: pid, + addr: goodTCP, + resp: resch, + }) + } + + select { + case <-resch: + t.Fatal("no dials should have completed!") + case <-time.After(time.Millisecond * 100): + } + + pid5 := peer.ID("testpeer5") + utpaddr := ma.StringCast("/ip4/127.0.0.1/udp/7777/utp") + + // This should complete immediately since utp addresses arent blocked by fd rate limiting + l.AddDialJob(&dialJob{ctx: ctx, peer: pid5, addr: utpaddr, resp: resch}) + + select { + case res := <-resch: + if res.Err != nil { + t.Fatal("should have gotten successful response") + } + case <-time.After(time.Second * 5): + t.Fatal("timeout waiting for utp addr success") + } + + // A relay address with tcp transport will complete because we do not consume fds for dials + // with relay addresses as the fd will be consumed 
when we actually dial the relay server. + pid6 := test.RandPeerIDFatal(t) + relayAddr := ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/tcp/20/p2p-circuit/p2p/%s", pid6)) + l.AddDialJob(&dialJob{ctx: ctx, peer: pid6, addr: relayAddr, resp: resch}) + + select { + case res := <-resch: + if res.Err != nil { + t.Fatal("should have gotten successful response") + } + case <-time.After(time.Second * 5): + t.Fatal("timeout waiting for relay addr success") + } +} + +func TestTokenRedistribution(t *testing.T) { + var lk sync.Mutex + hangchs := make(map[peer.ID]chan struct{}) + df := func(ctx context.Context, p peer.ID, a ma.Multiaddr) (transport.CapableConn, error) { + if tcpPortOver(a, 10) { + return (transport.CapableConn)(nil), nil + } + + lk.Lock() + ch := hangchs[p] + lk.Unlock() + <-ch + return nil, fmt.Errorf("test bad dial") + } + l := newDialLimiterWithParams(df, 8, 4) + + bads := []ma.Multiaddr{addrWithPort(1), addrWithPort(2), addrWithPort(3), addrWithPort(4)} + pids := []peer.ID{"testpeer1", "testpeer2"} + + ctx := context.Background() + resch := make(chan dialResult) + + // take all fd limit tokens with hang dials + for _, pid := range pids { + hangchs[pid] = make(chan struct{}) + } + + for _, pid := range pids { + tryDialAddrs(ctx, l, pid, bads, resch) + } + + // add a good dial job for peer 1 + l.AddDialJob(&dialJob{ + ctx: ctx, + peer: pids[1], + addr: ma.StringCast("/ip4/127.0.0.1/tcp/1001"), + resp: resch, + }) + + select { + case <-resch: + t.Fatal("no dials should have completed!") + case <-time.After(time.Millisecond * 100): + } + + // unblock one dial for peer 0 + hangchs[pids[0]] <- struct{}{} + + select { + case res := <-resch: + if res.Err == nil { + t.Fatal("should have only been a failure here") + } + case <-time.After(time.Millisecond * 100): + t.Fatal("expected a dial failure here") + } + + select { + case <-resch: + t.Fatal("no more dials should have completed!") + case <-time.After(time.Millisecond * 100): + } + + // add a bad dial job to peer 0 to fill their rate limiter + // and test that more dials for this peer won't interfere with peer 1's successful dial incoming + l.AddDialJob(&dialJob{ + ctx: ctx, + peer: pids[0], + addr: addrWithPort(7), + resp: resch, + }) + + hangchs[pids[1]] <- struct{}{} + + // now one failed dial from peer 1 should get through and fail + // which will in turn unblock the successful dial on peer 1 + select { + case res := <-resch: + if res.Err == nil { + t.Fatal("should have only been a failure here") + } + case <-time.After(time.Millisecond * 100): + t.Fatal("expected a dial failure here") + } + + select { + case res := <-resch: + if res.Err != nil { + t.Fatal("should have succeeded!") + } + case <-time.After(time.Millisecond * 100): + t.Fatal("should have gotten successful dial") + } +} + +func TestStressLimiter(t *testing.T) { + df := func(ctx context.Context, p peer.ID, a ma.Multiaddr) (transport.CapableConn, error) { + if tcpPortOver(a, 1000) { + return transport.CapableConn(nil), nil + } + + time.Sleep(time.Millisecond * time.Duration(5+rand.Intn(100))) + return nil, fmt.Errorf("test bad dial") + } + + l := newDialLimiterWithParams(df, 20, 5) + + var bads []ma.Multiaddr + for i := 0; i < 100; i++ { + bads = append(bads, addrWithPort(i)) + } + + addresses := append(bads, addrWithPort(2000)) + success := make(chan struct{}) + + for i := 0; i < 20; i++ { + go func(id peer.ID) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resp := make(chan dialResult) + time.Sleep(time.Duration(rand.Intn(10)) * 
time.Millisecond) + for _, i := range rand.Perm(len(addresses)) { + l.AddDialJob(&dialJob{ + addr: addresses[i], + ctx: ctx, + peer: id, + resp: resp, + }) + } + + for res := range resp { + if res.Err == nil { + success <- struct{}{} + return + } + } + }(peer.ID(fmt.Sprintf("testpeer%d", i))) + } + + for i := 0; i < 20; i++ { + select { + case <-success: + case <-time.After(time.Minute): + t.Fatal("expected a success within five seconds") + } + } +} + +func TestFDLimitUnderflow(t *testing.T) { + df := func(ctx context.Context, p peer.ID, addr ma.Multiaddr) (transport.CapableConn, error) { + select { + case <-ctx.Done(): + case <-time.After(5 * time.Second): + } + return nil, fmt.Errorf("df timed out") + } + + const fdLimit = 20 + l := newDialLimiterWithParams(df, fdLimit, 3) + + var addrs []ma.Multiaddr + for i := 0; i <= 1000; i++ { + addrs = append(addrs, addrWithPort(i)) + } + + wg := sync.WaitGroup{} + const num = 3 * fdLimit + wg.Add(num) + errs := make(chan error, num) + for i := 0; i < num; i++ { + go func(id peer.ID, i int) { + defer wg.Done() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resp := make(chan dialResult) + l.AddDialJob(&dialJob{ + addr: addrs[i], + ctx: ctx, + peer: id, + resp: resp, + }) + + for res := range resp { + if res.Err != nil { + return + } + errs <- errors.New("got dial res, but shouldn't") + } + }(peer.ID(fmt.Sprintf("testpeer%d", i%20)), i) + } + + go func() { + wg.Wait() + close(errs) + }() + + for err := range errs { + t.Fatal(err) + } + + l.lk.Lock() + fdConsuming := l.fdConsuming + l.lk.Unlock() + + if fdConsuming < 0 { + t.Fatalf("l.fdConsuming < 0") + } +} diff --git a/p2p/net/swarm/peers_test.go b/p2p/net/swarm/peers_test.go new file mode 100644 index 0000000000..a2d8f820e2 --- /dev/null +++ b/p2p/net/swarm/peers_test.go @@ -0,0 +1,65 @@ +package swarm_test + +import ( + "context" + "testing" + "time" + + . "github.com/libp2p/go-libp2p/p2p/net/swarm" + + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/peerstore" + + ma "github.com/multiformats/go-multiaddr" + + "github.com/stretchr/testify/require" +) + +func TestPeers(t *testing.T) { + ctx := context.Background() + swarms := makeSwarms(t, 2) + s1 := swarms[0] + s2 := swarms[1] + + connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { + // TODO: make a DialAddr func. 
+ s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) + // t.Logf("connections from %s", s.LocalPeer()) + // for _, c := range s.ConnsToPeer(dst) { + // t.Logf("connection from %s to %s: %v", s.LocalPeer(), dst, c) + // } + // t.Logf("") + if _, err := s.DialPeer(ctx, dst); err != nil { + t.Fatal("error swarm dialing to peer", err) + } + // t.Log(s.swarm.Dump()) + } + + connect(s1, s2.LocalPeer(), s2.ListenAddresses()[0]) + require.Eventually(t, func() bool { return len(s2.Peers()) > 0 }, 3*time.Second, 50*time.Millisecond) + connect(s2, s1.LocalPeer(), s1.ListenAddresses()[0]) + + for i := 0; i < 100; i++ { + connect(s1, s2.LocalPeer(), s2.ListenAddresses()[0]) + connect(s2, s1.LocalPeer(), s1.ListenAddresses()[0]) + } + + for _, s := range swarms { + log.Infof("%s swarm routing table: %s", s.LocalPeer(), s.Peers()) + } + + test := func(s *Swarm) { + expect := 1 + actual := len(s.Peers()) + if actual != expect { + t.Errorf("%s has %d peers, not %d: %v", s.LocalPeer(), actual, expect, s.Peers()) + } + actual = len(s.Conns()) + if actual != expect { + t.Errorf("%s has %d conns, not %d: %v", s.LocalPeer(), actual, expect, s.Conns()) + } + } + + test(s1) + test(s2) +} diff --git a/p2p/net/swarm/simul_test.go b/p2p/net/swarm/simul_test.go new file mode 100644 index 0000000000..5f3d6d34ae --- /dev/null +++ b/p2p/net/swarm/simul_test.go @@ -0,0 +1,78 @@ +package swarm_test + +import ( + "context" + "runtime" + "sync" + "testing" + "time" + + . "github.com/libp2p/go-libp2p/p2p/net/swarm" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/peerstore" + + "github.com/libp2p/go-libp2p-testing/ci" + ma "github.com/multiformats/go-multiaddr" +) + +func TestSimultOpen(t *testing.T) { + t.Parallel() + swarms := makeSwarms(t, 2, swarmt.OptDisableReuseport) + + // connect everyone + { + var wg sync.WaitGroup + connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { + defer wg.Done() + // copy for other peer + log.Debugf("TestSimultOpen: connecting: %s --> %s (%s)", s.LocalPeer(), dst, addr) + s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) + if _, err := s.DialPeer(context.Background(), dst); err != nil { + t.Error("error swarm dialing to peer", err) + } + } + + log.Info("Connecting swarms simultaneously.") + wg.Add(2) + go connect(swarms[0], swarms[1].LocalPeer(), swarms[1].ListenAddresses()[0]) + go connect(swarms[1], swarms[0].LocalPeer(), swarms[0].ListenAddresses()[0]) + wg.Wait() + } + + for _, s := range swarms { + s.Close() + } +} + +func TestSimultOpenMany(t *testing.T) { + // t.Skip("very very slow") + + addrs := 20 + rounds := 10 + if ci.IsRunning() || runtime.GOOS == "darwin" { + // osx has a limit of 256 file descriptors + addrs = 10 + rounds = 5 + } + subtestSwarm(t, addrs, rounds) +} + +func TestSimultOpenFewStress(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + // t.Skip("skipping for another test") + t.Parallel() + + msgs := 40 + swarms := 2 + rounds := 10 + // rounds := 100 + + for i := 0; i < rounds; i++ { + subtestSwarm(t, swarms, msgs) + <-time.After(10 * time.Millisecond) + } +} diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go new file mode 100644 index 0000000000..9df1840370 --- /dev/null +++ b/p2p/net/swarm/swarm.go @@ -0,0 +1,607 @@ +package swarm + +import ( + "context" + "errors" + "fmt" + "io" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/libp2p/go-libp2p-core/connmgr" + 
"github.com/libp2p/go-libp2p-core/metrics" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/peerstore" + "github.com/libp2p/go-libp2p-core/transport" + + logging "github.com/ipfs/go-log/v2" + ma "github.com/multiformats/go-multiaddr" +) + +const ( + defaultDialTimeout = 15 * time.Second + + // defaultDialTimeoutLocal is the maximum duration a Dial to local network address + // is allowed to take. + // This includes the time between dialing the raw network connection, + // protocol selection as well the handshake, if applicable. + defaultDialTimeoutLocal = 5 * time.Second +) + +var log = logging.Logger("swarm2") + +// ErrSwarmClosed is returned when one attempts to operate on a closed swarm. +var ErrSwarmClosed = errors.New("swarm closed") + +// ErrAddrFiltered is returned when trying to register a connection to a +// filtered address. You shouldn't see this error unless some underlying +// transport is misbehaving. +var ErrAddrFiltered = errors.New("address filtered") + +// ErrDialTimeout is returned when one a dial times out due to the global timeout +var ErrDialTimeout = errors.New("dial timed out") + +type Option func(*Swarm) error + +// WithConnectionGater sets a connection gater +func WithConnectionGater(gater connmgr.ConnectionGater) Option { + return func(s *Swarm) error { + s.gater = gater + return nil + } +} + +// WithMetrics sets a metrics reporter +func WithMetrics(reporter metrics.Reporter) Option { + return func(s *Swarm) error { + s.bwc = reporter + return nil + } +} + +func WithDialTimeout(t time.Duration) Option { + return func(s *Swarm) error { + s.dialTimeout = t + return nil + } +} + +func WithDialTimeoutLocal(t time.Duration) Option { + return func(s *Swarm) error { + s.dialTimeoutLocal = t + return nil + } +} + +func WithResourceManager(m network.ResourceManager) Option { + return func(s *Swarm) error { + s.rcmgr = m + return nil + } +} + +// Swarm is a connection muxer, allowing connections to other peers to +// be opened and closed, while still using the same Chan for all +// communication. The Chan sends/receives Messages, which note the +// destination or source Peer. +type Swarm struct { + nextConnID uint64 // guarded by atomic + nextStreamID uint64 // guarded by atomic + + // Close refcount. This allows us to fully wait for the swarm to be torn + // down before continuing. + refs sync.WaitGroup + + rcmgr network.ResourceManager + + local peer.ID + peers peerstore.Peerstore + + dialTimeout time.Duration + dialTimeoutLocal time.Duration + + conns struct { + sync.RWMutex + m map[peer.ID][]*Conn + } + + listeners struct { + sync.RWMutex + + ifaceListenAddres []ma.Multiaddr + cacheEOL time.Time + + m map[transport.Listener]struct{} + } + + notifs struct { + sync.RWMutex + m map[network.Notifiee]struct{} + } + + transports struct { + sync.RWMutex + m map[int]transport.Transport + } + + // stream handlers + streamh atomic.Value + + // dialing helpers + dsync *dialSync + backf DialBackoff + limiter *dialLimiter + gater connmgr.ConnectionGater + + closeOnce sync.Once + ctx context.Context // is canceled when Close is called + ctxCancel context.CancelFunc + + bwc metrics.Reporter +} + +// NewSwarm constructs a Swarm. 
+func NewSwarm(local peer.ID, peers peerstore.Peerstore, opts ...Option) (*Swarm, error) { + ctx, cancel := context.WithCancel(context.Background()) + s := &Swarm{ + local: local, + peers: peers, + ctx: ctx, + ctxCancel: cancel, + dialTimeout: defaultDialTimeout, + dialTimeoutLocal: defaultDialTimeoutLocal, + } + + s.conns.m = make(map[peer.ID][]*Conn) + s.listeners.m = make(map[transport.Listener]struct{}) + s.transports.m = make(map[int]transport.Transport) + s.notifs.m = make(map[network.Notifiee]struct{}) + + for _, opt := range opts { + if err := opt(s); err != nil { + return nil, err + } + } + if s.rcmgr == nil { + s.rcmgr = network.NullResourceManager + } + + s.dsync = newDialSync(s.dialWorkerLoop) + s.limiter = newDialLimiter(s.dialAddr) + s.backf.init(s.ctx) + return s, nil +} + +func (s *Swarm) Close() error { + s.closeOnce.Do(s.close) + return nil +} + +func (s *Swarm) close() { + s.ctxCancel() + + // Prevents new connections and/or listeners from being added to the swarm. + s.listeners.Lock() + listeners := s.listeners.m + s.listeners.m = nil + s.listeners.Unlock() + + s.conns.Lock() + conns := s.conns.m + s.conns.m = nil + s.conns.Unlock() + + // Lots of goroutines but we might as well do this in parallel. We want to shut down as fast as + // possible. + + for l := range listeners { + go func(l transport.Listener) { + if err := l.Close(); err != nil { + log.Errorf("error when shutting down listener: %s", err) + } + }(l) + } + + for _, cs := range conns { + for _, c := range cs { + go func(c *Conn) { + if err := c.Close(); err != nil { + log.Errorf("error when shutting down connection: %s", err) + } + }(c) + } + } + + // Wait for everything to finish. + s.refs.Wait() + + // Now close out any transports (if necessary). Do this after closing + // all connections/listeners. + s.transports.Lock() + transports := s.transports.m + s.transports.m = nil + s.transports.Unlock() + + var wg sync.WaitGroup + for _, t := range transports { + if closer, ok := t.(io.Closer); ok { + wg.Add(1) + go func(c io.Closer) { + defer wg.Done() + if err := closer.Close(); err != nil { + log.Errorf("error when closing down transport %T: %s", c, err) + } + }(closer) + } + } + wg.Wait() +} + +func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, error) { + var ( + p = tc.RemotePeer() + addr = tc.RemoteMultiaddr() + ) + + // create the Stat object, initializing with the underlying connection Stat if available + var stat network.ConnStats + if cs, ok := tc.(network.ConnStat); ok { + stat = cs.Stat() + } + stat.Direction = dir + stat.Opened = time.Now() + + // Wrap and register the connection. + c := &Conn{ + conn: tc, + swarm: s, + stat: stat, + id: atomic.AddUint64(&s.nextConnID, 1), + } + + // we ONLY check upgraded connections here so we can send them a Disconnect message. + // If we do this in the Upgrader, we will not be able to do this. + if s.gater != nil { + if allow, _ := s.gater.InterceptUpgraded(c); !allow { + // TODO Send disconnect with reason here + err := tc.Close() + if err != nil { + log.Warnf("failed to close connection with peer %s and addr %s; err: %s", p.Pretty(), addr, err) + } + return nil, ErrGaterDisallowedConnection + } + } + + // Add the public key. + if pk := tc.RemotePublicKey(); pk != nil { + s.peers.AddPubKey(p, pk) + } + + // Clear any backoffs + s.backf.Clear(p) + + // Finally, add the peer. 
+ s.conns.Lock() + // Check if we're still online + if s.conns.m == nil { + s.conns.Unlock() + tc.Close() + return nil, ErrSwarmClosed + } + + c.streams.m = make(map[*Stream]struct{}) + s.conns.m[p] = append(s.conns.m[p], c) + + // Add two swarm refs: + // * One will be decremented after the close notifications fire in Conn.doClose + // * The other will be decremented when Conn.start exits. + s.refs.Add(2) + + // Take the notification lock before releasing the conns lock to block + // Disconnect notifications until after the Connect notifications done. + c.notifyLk.Lock() + s.conns.Unlock() + + s.notifyAll(func(f network.Notifiee) { + f.Connected(s, c) + }) + c.notifyLk.Unlock() + + c.start() + return c, nil +} + +// Peerstore returns this swarms internal Peerstore. +func (s *Swarm) Peerstore() peerstore.Peerstore { + return s.peers +} + +// SetStreamHandler assigns the handler for new streams. +func (s *Swarm) SetStreamHandler(handler network.StreamHandler) { + s.streamh.Store(handler) +} + +// StreamHandler gets the handler for new streams. +func (s *Swarm) StreamHandler() network.StreamHandler { + handler, _ := s.streamh.Load().(network.StreamHandler) + return handler +} + +// NewStream creates a new stream on any available connection to peer, dialing +// if necessary. +func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error) { + log.Debugf("[%s] opening stream to peer [%s]", s.local, p) + + // Algorithm: + // 1. Find the best connection, otherwise, dial. + // 2. Try opening a stream. + // 3. If the underlying connection is, in fact, closed, close the outer + // connection and try again. We do this in case we have a closed + // connection but don't notice it until we actually try to open a + // stream. + // + // Note: We only dial once. + // + // TODO: Try all connections even if we get an error opening a stream on + // a non-closed connection. + dials := 0 + for { + // will prefer direct connections over relayed connections for opening streams + c := s.bestConnToPeer(p) + if c == nil { + if nodial, _ := network.GetNoDial(ctx); nodial { + return nil, network.ErrNoConn + } + + if dials >= DialAttempts { + return nil, errors.New("max dial attempts exceeded") + } + dials++ + + var err error + c, err = s.dialPeer(ctx, p) + if err != nil { + return nil, err + } + } + + s, err := c.NewStream(ctx) + if err != nil { + if c.conn.IsClosed() { + continue + } + return nil, err + } + return s, nil + } +} + +// ConnsToPeer returns all the live connections to peer. +func (s *Swarm) ConnsToPeer(p peer.ID) []network.Conn { + // TODO: Consider sorting the connection list best to worst. Currently, + // it's sorted oldest to newest. + s.conns.RLock() + defer s.conns.RUnlock() + conns := s.conns.m[p] + output := make([]network.Conn, len(conns)) + for i, c := range conns { + output[i] = c + } + return output +} + +func isBetterConn(a, b *Conn) bool { + // If one is transient and not the other, prefer the non-transient connection. + aTransient := a.Stat().Transient + bTransient := b.Stat().Transient + if aTransient != bTransient { + return !aTransient + } + + // If one is direct and not the other, prefer the direct connection. + aDirect := isDirectConn(a) + bDirect := isDirectConn(b) + if aDirect != bDirect { + return aDirect + } + + // Otherwise, prefer the connection with more open streams. 
+ a.streams.Lock() + aLen := len(a.streams.m) + a.streams.Unlock() + + b.streams.Lock() + bLen := len(b.streams.m) + b.streams.Unlock() + + if aLen != bLen { + return aLen > bLen + } + + // finally, pick the last connection. + return true +} + +// bestConnToPeer returns the best connection to peer. +func (s *Swarm) bestConnToPeer(p peer.ID) *Conn { + + // TODO: Prefer some transports over others. + // For now, prefers direct connections over Relayed connections. + // For tie-breaking, select the newest non-closed connection with the most streams. + s.conns.RLock() + defer s.conns.RUnlock() + + var best *Conn + for _, c := range s.conns.m[p] { + if c.conn.IsClosed() { + // We *will* garbage collect this soon anyways. + continue + } + if best == nil || isBetterConn(c, best) { + best = c + } + } + return best +} + +func (s *Swarm) bestAcceptableConnToPeer(ctx context.Context, p peer.ID) *Conn { + conn := s.bestConnToPeer(p) + if conn != nil { + forceDirect, _ := network.GetForceDirectDial(ctx) + if !forceDirect || isDirectConn(conn) { + return conn + } + } + return nil +} + +func isDirectConn(c *Conn) bool { + return c != nil && !c.conn.Transport().Proxy() +} + +// Connectedness returns our "connectedness" state with the given peer. +// +// To check if we have an open connection, use `s.Connectedness(p) == +// network.Connected`. +func (s *Swarm) Connectedness(p peer.ID) network.Connectedness { + if s.bestConnToPeer(p) != nil { + return network.Connected + } + return network.NotConnected +} + +// Conns returns a slice of all connections. +func (s *Swarm) Conns() []network.Conn { + s.conns.RLock() + defer s.conns.RUnlock() + + conns := make([]network.Conn, 0, len(s.conns.m)) + for _, cs := range s.conns.m { + for _, c := range cs { + conns = append(conns, c) + } + } + return conns +} + +// ClosePeer closes all connections to the given peer. +func (s *Swarm) ClosePeer(p peer.ID) error { + conns := s.ConnsToPeer(p) + switch len(conns) { + case 0: + return nil + case 1: + return conns[0].Close() + default: + errCh := make(chan error) + for _, c := range conns { + go func(c network.Conn) { + errCh <- c.Close() + }(c) + } + + var errs []string + for range conns { + err := <-errCh + if err != nil { + errs = append(errs, err.Error()) + } + } + if len(errs) > 0 { + return fmt.Errorf("when disconnecting from peer %s: %s", p, strings.Join(errs, ", ")) + } + return nil + } +} + +// Peers returns a copy of the set of peers swarm is connected to. +func (s *Swarm) Peers() []peer.ID { + s.conns.RLock() + defer s.conns.RUnlock() + peers := make([]peer.ID, 0, len(s.conns.m)) + for p := range s.conns.m { + peers = append(peers, p) + } + + return peers +} + +// LocalPeer returns the local peer swarm is associated to. +func (s *Swarm) LocalPeer() peer.ID { + return s.local +} + +// Backoff returns the DialBackoff object for this swarm. 
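+//
+// Callers can use it, for example, to check whether an address is currently
+// in backoff before attempting another dial (sketch only; p and addr are
+// assumed to be a known peer.ID and ma.Multiaddr):
+//
+//	if s.Backoff().Backoff(p, addr) {
+//		// this address failed recently; skip it for now
+//	}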
+func (s *Swarm) Backoff() *DialBackoff { + return &s.backf +} + +// notifyAll sends a signal to all Notifiees +func (s *Swarm) notifyAll(notify func(network.Notifiee)) { + var wg sync.WaitGroup + + s.notifs.RLock() + wg.Add(len(s.notifs.m)) + for f := range s.notifs.m { + go func(f network.Notifiee) { + defer wg.Done() + notify(f) + }(f) + } + + wg.Wait() + s.notifs.RUnlock() +} + +// Notify signs up Notifiee to receive signals when events happen +func (s *Swarm) Notify(f network.Notifiee) { + s.notifs.Lock() + s.notifs.m[f] = struct{}{} + s.notifs.Unlock() +} + +// StopNotify unregisters Notifiee fromr receiving signals +func (s *Swarm) StopNotify(f network.Notifiee) { + s.notifs.Lock() + delete(s.notifs.m, f) + s.notifs.Unlock() +} + +func (s *Swarm) removeConn(c *Conn) { + p := c.RemotePeer() + + s.conns.Lock() + defer s.conns.Unlock() + cs := s.conns.m[p] + for i, ci := range cs { + if ci == c { + if len(cs) == 1 { + delete(s.conns.m, p) + } else { + // NOTE: We're intentionally preserving order. + // This way, connections to a peer are always + // sorted oldest to newest. + copy(cs[i:], cs[i+1:]) + cs[len(cs)-1] = nil + s.conns.m[p] = cs[:len(cs)-1] + } + return + } + } +} + +// String returns a string representation of Network. +func (s *Swarm) String() string { + return fmt.Sprintf("", s.LocalPeer()) +} + +func (s *Swarm) ResourceManager() network.ResourceManager { + return s.rcmgr +} + +// Swarm is a Network. +var _ network.Network = (*Swarm)(nil) +var _ transport.TransportNetwork = (*Swarm)(nil) diff --git a/p2p/net/swarm/swarm_addr.go b/p2p/net/swarm/swarm_addr.go new file mode 100644 index 0000000000..8d088e76df --- /dev/null +++ b/p2p/net/swarm/swarm_addr.go @@ -0,0 +1,72 @@ +package swarm + +import ( + "time" + + manet "github.com/multiformats/go-multiaddr/net" + + ma "github.com/multiformats/go-multiaddr" +) + +// ListenAddresses returns a list of addresses at which this swarm listens. +func (s *Swarm) ListenAddresses() []ma.Multiaddr { + s.listeners.RLock() + defer s.listeners.RUnlock() + return s.listenAddressesNoLock() +} + +func (s *Swarm) listenAddressesNoLock() []ma.Multiaddr { + addrs := make([]ma.Multiaddr, 0, len(s.listeners.m)) + for l := range s.listeners.m { + addrs = append(addrs, l.Multiaddr()) + } + return addrs +} + +const ifaceAddrsCacheDuration = 1 * time.Minute + +// InterfaceListenAddresses returns a list of addresses at which this swarm +// listens. It expands "any interface" addresses (/ip4/0.0.0.0, /ip6/::) to +// use the known local interfaces. +func (s *Swarm) InterfaceListenAddresses() ([]ma.Multiaddr, error) { + s.listeners.RLock() // RLock start + + ifaceListenAddres := s.listeners.ifaceListenAddres + isEOL := time.Now().After(s.listeners.cacheEOL) + s.listeners.RUnlock() // RLock end + + if !isEOL { + // Cache is valid, clone the slice + return append(ifaceListenAddres[:0:0], ifaceListenAddres...), nil + } + + // Cache is not valid + // Perfrom double checked locking + + s.listeners.Lock() // Lock start + + ifaceListenAddres = s.listeners.ifaceListenAddres + isEOL = time.Now().After(s.listeners.cacheEOL) + if isEOL { + // Cache is still invalid + listenAddres := s.listenAddressesNoLock() + if len(listenAddres) > 0 { + // We're actually listening on addresses. 
+ var err error + ifaceListenAddres, err = manet.ResolveUnspecifiedAddresses(listenAddres, nil) + if err != nil { + s.listeners.Unlock() // Lock early exit + return nil, err + } + } else { + ifaceListenAddres = nil + } + + s.listeners.ifaceListenAddres = ifaceListenAddres + s.listeners.cacheEOL = time.Now().Add(ifaceAddrsCacheDuration) + } + + s.listeners.Unlock() // Lock end + + return append(ifaceListenAddres[:0:0], ifaceListenAddres...), nil +} diff --git a/p2p/net/swarm/swarm_addr_test.go b/p2p/net/swarm/swarm_addr_test.go new file mode 100644 index 0000000000..2bdc1ba3ae --- /dev/null +++ b/p2p/net/swarm/swarm_addr_test.go @@ -0,0 +1,59 @@ +package swarm_test + +import ( + "context" + "testing" + + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + + "github.com/libp2p/go-libp2p-core/peerstore" + "github.com/libp2p/go-libp2p-core/test" + + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +func TestDialBadAddrs(t *testing.T) { + m := func(s string) ma.Multiaddr { + maddr, err := ma.NewMultiaddr(s) + if err != nil { + t.Fatal(err) + } + return maddr + } + + s := makeSwarms(t, 1)[0] + + test := func(a ma.Multiaddr) { + p := test.RandPeerIDFatal(t) + s.Peerstore().AddAddr(p, a, peerstore.PermanentAddrTTL) + if _, err := s.DialPeer(context.Background(), p); err == nil { + t.Errorf("swarm should not dial: %s", p) + } + } + + test(m("/ip6/fe80::1")) // link local + test(m("/ip6/fe80::100")) // link local + test(m("/ip4/127.0.0.1/udp/1234/utp")) // utp +} + +func TestAddrRace(t *testing.T) { + s := makeSwarms(t, 1)[0] + defer s.Close() + + a1, err := s.InterfaceListenAddresses() + require.NoError(t, err) + a2, err := s.InterfaceListenAddresses() + require.NoError(t, err) + + if len(a1) > 0 && len(a2) > 0 && &a1[0] == &a2[0] { + t.Fatal("got the exact same address set twice; this could lead to data races") + } +} + +func TestAddressesWithoutListening(t *testing.T) { + s := swarmt.GenSwarm(t, swarmt.OptDialOnly) + a1, err := s.InterfaceListenAddresses() + require.NoError(t, err) + require.Empty(t, a1, "expected to be listening on no addresses") +} diff --git a/p2p/net/swarm/swarm_conn.go b/p2p/net/swarm/swarm_conn.go new file mode 100644 index 0000000000..77b2defb52 --- /dev/null +++ b/p2p/net/swarm/swarm_conn.go @@ -0,0 +1,263 @@ +package swarm + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + ic "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/transport" + + ma "github.com/multiformats/go-multiaddr" +) + +// TODO: Put this elsewhere. + +// ErrConnClosed is returned when operating on a closed connection. +var ErrConnClosed = errors.New("connection closed") + +// Conn is the connection type used by swarm. In general, you won't use this +// type directly. +type Conn struct { + id uint64 + conn transport.CapableConn + swarm *Swarm + + closeOnce sync.Once + err error + + notifyLk sync.Mutex + + streams struct { + sync.Mutex + m map[*Stream]struct{} + } + + stat network.ConnStats +} + +var _ network.Conn = &Conn{} + +func (c *Conn) ID() string { + // format: - + return fmt.Sprintf("%s-%d", c.RemotePeer().Pretty()[0:10], c.id) +} + +// Close closes this connection. 
+// +// Note: This method won't wait for the close notifications to finish as that +// would create a deadlock when called from an open notification (because all +// open notifications must finish before we can fire off the close +// notifications). +func (c *Conn) Close() error { + c.closeOnce.Do(c.doClose) + return c.err +} + +func (c *Conn) doClose() { + c.swarm.removeConn(c) + + // Prevent new streams from opening. + c.streams.Lock() + streams := c.streams.m + c.streams.m = nil + c.streams.Unlock() + + c.err = c.conn.Close() + + // This is just for cleaning up state. The connection has already been closed. + // We *could* optimize this but it really isn't worth it. + for s := range streams { + s.Reset() + } + + // do this in a goroutine to avoid deadlocking if we call close in an open notification. + go func() { + // prevents us from issuing close notifications before finishing the open notifications + c.notifyLk.Lock() + defer c.notifyLk.Unlock() + + c.swarm.notifyAll(func(f network.Notifiee) { + f.Disconnected(c.swarm, c) + }) + c.swarm.refs.Done() // taken in Swarm.addConn + }() +} + +func (c *Conn) removeStream(s *Stream) { + c.streams.Lock() + c.stat.NumStreams-- + delete(c.streams.m, s) + c.streams.Unlock() + s.scope.Done() +} + +// listens for new streams. +// +// The caller must take a swarm ref before calling. This function decrements the +// swarm ref count. +func (c *Conn) start() { + go func() { + defer c.swarm.refs.Done() + defer c.Close() + + for { + ts, err := c.conn.AcceptStream() + if err != nil { + return + } + scope, err := c.swarm.ResourceManager().OpenStream(c.RemotePeer(), network.DirInbound) + if err != nil { + ts.Reset() + continue + } + c.swarm.refs.Add(1) + go func() { + s, err := c.addStream(ts, network.DirInbound, scope) + + // Don't defer this. We don't want to block + // swarm shutdown on the connection handler. + c.swarm.refs.Done() + + // We only get an error here when the swarm is closed or closing. 
+ if err != nil { + return + } + + if h := c.swarm.StreamHandler(); h != nil { + h(s) + } + }() + } + }() +} + +func (c *Conn) String() string { + return fmt.Sprintf( + " %s (%s)>", + c.conn.Transport(), + c.conn.LocalMultiaddr(), + c.conn.LocalPeer().Pretty(), + c.conn.RemoteMultiaddr(), + c.conn.RemotePeer().Pretty(), + ) +} + +// LocalMultiaddr is the Multiaddr on this side +func (c *Conn) LocalMultiaddr() ma.Multiaddr { + return c.conn.LocalMultiaddr() +} + +// LocalPeer is the Peer on our side of the connection +func (c *Conn) LocalPeer() peer.ID { + return c.conn.LocalPeer() +} + +// RemoteMultiaddr is the Multiaddr on the remote side +func (c *Conn) RemoteMultiaddr() ma.Multiaddr { + return c.conn.RemoteMultiaddr() +} + +// RemotePeer is the Peer on the remote side +func (c *Conn) RemotePeer() peer.ID { + return c.conn.RemotePeer() +} + +// LocalPrivateKey is the public key of the peer on this side +func (c *Conn) LocalPrivateKey() ic.PrivKey { + return c.conn.LocalPrivateKey() +} + +// RemotePublicKey is the public key of the peer on the remote side +func (c *Conn) RemotePublicKey() ic.PubKey { + return c.conn.RemotePublicKey() +} + +// Stat returns metadata pertaining to this connection +func (c *Conn) Stat() network.ConnStats { + c.streams.Lock() + defer c.streams.Unlock() + return c.stat +} + +// NewStream returns a new Stream from this connection +func (c *Conn) NewStream(ctx context.Context) (network.Stream, error) { + if c.Stat().Transient { + if useTransient, _ := network.GetUseTransient(ctx); !useTransient { + return nil, network.ErrTransientConn + } + } + + scope, err := c.swarm.ResourceManager().OpenStream(c.RemotePeer(), network.DirOutbound) + if err != nil { + return nil, err + } + ts, err := c.conn.OpenStream(ctx) + if err != nil { + scope.Done() + return nil, err + } + return c.addStream(ts, network.DirOutbound, scope) +} + +func (c *Conn) addStream(ts network.MuxedStream, dir network.Direction, scope network.StreamManagementScope) (*Stream, error) { + c.streams.Lock() + // Are we still online? + if c.streams.m == nil { + c.streams.Unlock() + scope.Done() + ts.Reset() + return nil, ErrConnClosed + } + + // Wrap and register the stream. + s := &Stream{ + stream: ts, + conn: c, + scope: scope, + stat: network.Stats{ + Direction: dir, + Opened: time.Now(), + }, + id: atomic.AddUint64(&c.swarm.nextStreamID, 1), + } + c.stat.NumStreams++ + c.streams.m[s] = struct{}{} + + // Released once the stream disconnect notifications have finished + // firing (in Swarm.remove). + c.swarm.refs.Add(1) + + // Take the notification lock before releasing the streams lock to block + // StreamClose notifications until after the StreamOpen notifications + // done. + s.notifyLk.Lock() + c.streams.Unlock() + + c.swarm.notifyAll(func(f network.Notifiee) { + f.OpenedStream(c.swarm, s) + }) + s.notifyLk.Unlock() + + return s, nil +} + +// GetStreams returns the streams associated with this connection. 
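+//
+// For example, a caller could tally the streams on a connection by direction
+// (sketch only, assuming c is an open *Conn):
+//
+//	var inbound, outbound int
+//	for _, str := range c.GetStreams() {
+//		if str.Stat().Direction == network.DirInbound {
+//			inbound++
+//		} else {
+//			outbound++
+//		}
+//	}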
+func (c *Conn) GetStreams() []network.Stream { + c.streams.Lock() + defer c.streams.Unlock() + streams := make([]network.Stream, 0, len(c.streams.m)) + for s := range c.streams.m { + streams = append(streams, s) + } + return streams +} + +func (c *Conn) Scope() network.ConnScope { + return c.conn.Scope() +} diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go new file mode 100644 index 0000000000..ca11d7326f --- /dev/null +++ b/p2p/net/swarm/swarm_dial.go @@ -0,0 +1,440 @@ +package swarm + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/transport" + + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +// Diagram of dial sync: +// +// many callers of Dial() synched w. dials many addrs results to callers +// ----------------------\ dialsync use earliest /-------------- +// -----------------------\ |----------\ /---------------- +// ------------------------>------------<------- >---------<----------------- +// -----------------------| \----x \---------------- +// ----------------------| \-----x \--------------- +// any may fail if no addr at end +// retry dialAttempt x + +var ( + // ErrDialBackoff is returned by the backoff code when a given peer has + // been dialed too frequently + ErrDialBackoff = errors.New("dial backoff") + + // ErrDialToSelf is returned if we attempt to dial our own peer + ErrDialToSelf = errors.New("dial to self attempted") + + // ErrNoTransport is returned when we don't know a transport for the + // given multiaddr. + ErrNoTransport = errors.New("no transport for protocol") + + // ErrAllDialsFailed is returned when connecting to a peer has ultimately failed + ErrAllDialsFailed = errors.New("all dials failed") + + // ErrNoAddresses is returned when we fail to find any addresses for a + // peer we're trying to dial. + ErrNoAddresses = errors.New("no addresses") + + // ErrNoGoodAddresses is returned when we find addresses for a peer but + // can't use any of them. + ErrNoGoodAddresses = errors.New("no good addresses") + + // ErrGaterDisallowedConnection is returned when the gater prevents us from + // forming a connection with a peer. + ErrGaterDisallowedConnection = errors.New("gater disallows connection to peer") +) + +// DialAttempts governs how many times a goroutine will try to dial a given peer. +// Note: this is down to one, as we have _too many dials_ atm. To add back in, +// add loop back in Dial(.) +const DialAttempts = 1 + +// ConcurrentFdDials is the number of concurrent outbound dials over transports +// that consume file descriptors +const ConcurrentFdDials = 160 + +// DefaultPerPeerRateLimit is the number of concurrent outbound dials to make +// per peer +const DefaultPerPeerRateLimit = 8 + +// dialbackoff is a struct used to avoid over-dialing the same, dead peers. +// Whenever we totally time out on a peer (all three attempts), we add them +// to dialbackoff. Then, whenevers goroutines would _wait_ (dialsync), they +// check dialbackoff. If it's there, they don't wait and exit promptly with +// an error. (the single goroutine that is actually dialing continues to +// dial). If a dial is successful, the peer is removed from backoff. 
+// Example: +// +// for { +// if ok, wait := dialsync.Lock(p); !ok { +// if backoff.Backoff(p) { +// return errDialFailed +// } +// <-wait +// continue +// } +// defer dialsync.Unlock(p) +// c, err := actuallyDial(p) +// if err != nil { +// dialbackoff.AddBackoff(p) +// continue +// } +// dialbackoff.Clear(p) +// } +// + +// DialBackoff is a type for tracking peer dial backoffs. +// +// * It's safe to use its zero value. +// * It's thread-safe. +// * It's *not* safe to move this type after using. +type DialBackoff struct { + entries map[peer.ID]map[string]*backoffAddr + lock sync.RWMutex +} + +type backoffAddr struct { + tries int + until time.Time +} + +func (db *DialBackoff) init(ctx context.Context) { + if db.entries == nil { + db.entries = make(map[peer.ID]map[string]*backoffAddr) + } + go db.background(ctx) +} + +func (db *DialBackoff) background(ctx context.Context) { + ticker := time.NewTicker(BackoffMax) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + db.cleanup() + } + } +} + +// Backoff returns whether the client should backoff from dialing +// peer p at address addr +func (db *DialBackoff) Backoff(p peer.ID, addr ma.Multiaddr) (backoff bool) { + db.lock.Lock() + defer db.lock.Unlock() + + ap, found := db.entries[p][string(addr.Bytes())] + return found && time.Now().Before(ap.until) +} + +// BackoffBase is the base amount of time to backoff (default: 5s). +var BackoffBase = time.Second * 5 + +// BackoffCoef is the backoff coefficient (default: 1s). +var BackoffCoef = time.Second + +// BackoffMax is the maximum backoff time (default: 5m). +var BackoffMax = time.Minute * 5 + +// AddBackoff lets other nodes know that we've entered backoff with +// peer p, so dialers should not wait unnecessarily. We still will +// attempt to dial with one goroutine, in case we get through. +// +// Backoff is not exponential, it's quadratic and computed according to the +// following formula: +// +// BackoffBase + BakoffCoef * PriorBackoffs^2 +// +// Where PriorBackoffs is the number of previous backoffs. +func (db *DialBackoff) AddBackoff(p peer.ID, addr ma.Multiaddr) { + saddr := string(addr.Bytes()) + db.lock.Lock() + defer db.lock.Unlock() + bp, ok := db.entries[p] + if !ok { + bp = make(map[string]*backoffAddr, 1) + db.entries[p] = bp + } + ba, ok := bp[saddr] + if !ok { + bp[saddr] = &backoffAddr{ + tries: 1, + until: time.Now().Add(BackoffBase), + } + return + } + + backoffTime := BackoffBase + BackoffCoef*time.Duration(ba.tries*ba.tries) + if backoffTime > BackoffMax { + backoffTime = BackoffMax + } + ba.until = time.Now().Add(backoffTime) + ba.tries++ +} + +// Clear removes a backoff record. Clients should call this after a +// successful Dial. +func (db *DialBackoff) Clear(p peer.ID) { + db.lock.Lock() + defer db.lock.Unlock() + delete(db.entries, p) +} + +func (db *DialBackoff) cleanup() { + db.lock.Lock() + defer db.lock.Unlock() + now := time.Now() + for p, e := range db.entries { + good := false + for _, backoff := range e { + backoffTime := BackoffBase + BackoffCoef*time.Duration(backoff.tries*backoff.tries) + if backoffTime > BackoffMax { + backoffTime = BackoffMax + } + if now.Before(backoff.until.Add(backoffTime)) { + good = true + break + } + } + if !good { + delete(db.entries, p) + } + } +} + +// DialPeer connects to a peer. +// +// The idea is that the client of Swarm does not need to know what network +// the connection will happen over. Swarm can use whichever it choses. 
+// This allows us to use various transport protocols, do NAT traversal/relay, +// etc. to achieve connection. +func (s *Swarm) DialPeer(ctx context.Context, p peer.ID) (network.Conn, error) { + if s.gater != nil && !s.gater.InterceptPeerDial(p) { + log.Debugf("gater disallowed outbound connection to peer %s", p.Pretty()) + return nil, &DialError{Peer: p, Cause: ErrGaterDisallowedConnection} + } + + // Avoid typed nil issues. + c, err := s.dialPeer(ctx, p) + if err != nil { + return nil, err + } + return c, nil +} + +// internal dial method that returns an unwrapped conn +// +// It is gated by the swarm's dial synchronization systems: dialsync and +// dialbackoff. +func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) { + log.Debugw("dialing peer", "from", s.local, "to", p) + err := p.Validate() + if err != nil { + return nil, err + } + + if p == s.local { + return nil, ErrDialToSelf + } + + // check if we already have an open (usable) connection first + conn := s.bestAcceptableConnToPeer(ctx, p) + if conn != nil { + return conn, nil + } + + // apply the DialPeer timeout + ctx, cancel := context.WithTimeout(ctx, network.GetDialPeerTimeout(ctx)) + defer cancel() + + conn, err = s.dsync.Dial(ctx, p) + if err == nil { + return conn, nil + } + + log.Debugf("network for %s finished dialing %s", s.local, p) + + if ctx.Err() != nil { + // Context error trumps any dial errors as it was likely the ultimate cause. + return nil, ctx.Err() + } + + if s.ctx.Err() != nil { + // Ok, so the swarm is shutting down. + return nil, ErrSwarmClosed + } + + return nil, err +} + +// dialWorkerLoop synchronizes and executes concurrent dials to a single peer +func (s *Swarm) dialWorkerLoop(p peer.ID, reqch <-chan dialRequest) { + w := newDialWorker(s, p, reqch) + w.loop() +} + +func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, error) { + peerAddrs := s.peers.Addrs(p) + if len(peerAddrs) == 0 { + return nil, ErrNoAddresses + } + + goodAddrs := s.filterKnownUndialables(p, peerAddrs) + if forceDirect, _ := network.GetForceDirectDial(ctx); forceDirect { + goodAddrs = ma.FilterAddrs(goodAddrs, s.nonProxyAddr) + } + + if len(goodAddrs) == 0 { + return nil, ErrNoGoodAddresses + } + + return goodAddrs, nil +} + +func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr, resch chan dialResult) error { + // check the dial backoff + if forceDirect, _ := network.GetForceDirectDial(ctx); !forceDirect { + if s.backf.Backoff(p, addr) { + return ErrDialBackoff + } + } + + // start the dial + s.limitedDial(ctx, p, addr, resch) + + return nil +} + +func (s *Swarm) canDial(addr ma.Multiaddr) bool { + t := s.TransportForDialing(addr) + return t != nil && t.CanDial(addr) +} + +func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool { + t := s.TransportForDialing(addr) + return !t.Proxy() +} + +// filterKnownUndialables takes a list of multiaddrs, and removes those +// that we definitely don't want to dial: addresses configured to be blocked, +// IPv6 link-local addresses, addresses without a dial-capable transport, +// and addresses that we know to be our own. +// This is an optimization to avoid wasting time on dials that we know are going to fail. 
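+//
+// For example (mirroring the addresses exercised in TestDialBadAddrs):
+// /ip6/fe80::1 is removed as an IPv6 link-local address,
+// /ip4/127.0.0.1/udp/1234/utp is removed when no registered transport can
+// dial it, and any of our own listen addresses are removed as well.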
+func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Multiaddr { + lisAddrs, _ := s.InterfaceListenAddresses() + var ourAddrs []ma.Multiaddr + for _, addr := range lisAddrs { + protos := addr.Protocols() + // we're only sure about filtering out /ip4 and /ip6 addresses, so far + if protos[0].Code == ma.P_IP4 || protos[0].Code == ma.P_IP6 { + ourAddrs = append(ourAddrs, addr) + } + } + + return ma.FilterAddrs(addrs, + func(addr ma.Multiaddr) bool { + for _, a := range ourAddrs { + if a.Equal(addr) { + return false + } + } + return true + }, + s.canDial, + // TODO: Consider allowing link-local addresses + func(addr ma.Multiaddr) bool { return !manet.IsIP6LinkLocal(addr) }, + func(addr ma.Multiaddr) bool { + return s.gater == nil || s.gater.InterceptAddrDial(p, addr) + }, + ) +} + +// limitedDial will start a dial to the given peer when +// it is able, respecting the various different types of rate +// limiting that occur without using extra goroutines per addr +func (s *Swarm) limitedDial(ctx context.Context, p peer.ID, a ma.Multiaddr, resp chan dialResult) { + timeout := s.dialTimeout + if lowTimeoutFilters.AddrBlocked(a) && s.dialTimeoutLocal < s.dialTimeout { + timeout = s.dialTimeoutLocal + } + s.limiter.AddDialJob(&dialJob{ + addr: a, + peer: p, + resp: resp, + ctx: ctx, + timeout: timeout, + }) +} + +// dialAddr is the actual dial for an addr, indirectly invoked through the limiter +func (s *Swarm) dialAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr) (transport.CapableConn, error) { + // Just to double check. Costs nothing. + if s.local == p { + return nil, ErrDialToSelf + } + log.Debugf("%s swarm dialing %s %s", s.local, p, addr) + + tpt := s.TransportForDialing(addr) + if tpt == nil { + return nil, ErrNoTransport + } + + connC, err := tpt.Dial(ctx, addr, p) + if err != nil { + return nil, err + } + + // Trust the transport? Yeah... right. + if connC.RemotePeer() != p { + connC.Close() + err = fmt.Errorf("BUG in transport %T: tried to dial %s, dialed %s", p, connC.RemotePeer(), tpt) + log.Error(err) + return nil, err + } + + // success! we got one! + return connC, nil +} + +// TODO We should have a `IsFdConsuming() bool` method on the `Transport` interface in go-libp2p-core/transport. +// This function checks if any of the transport protocols in the address requires a file descriptor. +// For now: +// A Non-circuit address which has the TCP/UNIX protocol is deemed FD consuming. +// For a circuit-relay address, we look at the address of the relay server/proxy +// and use the same logic as above to decide. 
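+//
+// Illustrative examples of that rule (placeholder addresses, added for
+// clarity):
+//
+//	/ip4/1.2.3.4/tcp/4001                  -> true  (plain TCP)
+//	/ip4/1.2.3.4/udp/4001/quic             -> false (UDP-based, no FD per dial)
+//	/ip4/1.2.3.4/tcp/4001/p2p-circuit      -> true  (relay server reached over TCP)
+//	/ip4/1.2.3.4/udp/4001/quic/p2p-circuit -> false (relay server reached over QUIC)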
+func isFdConsumingAddr(addr ma.Multiaddr) bool { + first, _ := ma.SplitFunc(addr, func(c ma.Component) bool { + return c.Protocol().Code == ma.P_CIRCUIT + }) + + // for safety + if first == nil { + return true + } + + _, err1 := first.ValueForProtocol(ma.P_TCP) + _, err2 := first.ValueForProtocol(ma.P_UNIX) + return err1 == nil || err2 == nil +} + +func isExpensiveAddr(addr ma.Multiaddr) bool { + _, err1 := addr.ValueForProtocol(ma.P_WS) + _, err2 := addr.ValueForProtocol(ma.P_WSS) + return err1 == nil || err2 == nil +} + +func isRelayAddr(addr ma.Multiaddr) bool { + _, err := addr.ValueForProtocol(ma.P_CIRCUIT) + return err == nil +} diff --git a/p2p/net/swarm/swarm_listen.go b/p2p/net/swarm/swarm_listen.go new file mode 100644 index 0000000000..ca54280c08 --- /dev/null +++ b/p2p/net/swarm/swarm_listen.go @@ -0,0 +1,121 @@ +package swarm + +import ( + "fmt" + "time" + + "github.com/libp2p/go-libp2p-core/network" + + ma "github.com/multiformats/go-multiaddr" +) + +// Listen sets up listeners for all of the given addresses. +// It returns as long as we successfully listen on at least *one* address. +func (s *Swarm) Listen(addrs ...ma.Multiaddr) error { + errs := make([]error, len(addrs)) + var succeeded int + for i, a := range addrs { + if err := s.AddListenAddr(a); err != nil { + errs[i] = err + } else { + succeeded++ + } + } + + for i, e := range errs { + if e != nil { + log.Warnw("listening failed", "on", addrs[i], "error", errs[i]) + } + } + + if succeeded == 0 && len(addrs) > 0 { + return fmt.Errorf("failed to listen on any addresses: %s", errs) + } + + return nil +} + +// AddListenAddr tells the swarm to listen on a single address. Unlike Listen, +// this method does not attempt to filter out bad addresses. +func (s *Swarm) AddListenAddr(a ma.Multiaddr) error { + tpt := s.TransportForListening(a) + if tpt == nil { + // TransportForListening will return nil if either: + // 1. No transport has been registered. + // 2. We're closed (so we've nulled out the transport map. + // + // Distinguish between these two cases to avoid confusing users. + select { + case <-s.ctx.Done(): + return ErrSwarmClosed + default: + return ErrNoTransport + } + } + + list, err := tpt.Listen(a) + if err != nil { + return err + } + + s.listeners.Lock() + if s.listeners.m == nil { + s.listeners.Unlock() + list.Close() + return ErrSwarmClosed + } + s.refs.Add(1) + s.listeners.m[list] = struct{}{} + s.listeners.cacheEOL = time.Time{} + s.listeners.Unlock() + + maddr := list.Multiaddr() + + // signal to our notifiees on listen. + s.notifyAll(func(n network.Notifiee) { + n.Listen(s, maddr) + }) + + go func() { + defer func() { + list.Close() + s.listeners.Lock() + delete(s.listeners.m, list) + s.listeners.cacheEOL = time.Time{} + s.listeners.Unlock() + + // signal to our notifiees on listen close. + s.notifyAll(func(n network.Notifiee) { + n.ListenClose(s, maddr) + }) + s.refs.Done() + }() + for { + c, err := list.Accept() + if err != nil { + if s.ctx.Err() == nil { + // only log if the swarm is still running. + log.Errorf("swarm listener accept error: %s", err) + } + return + } + + log.Debugf("swarm listener accepted connection: %s", c) + s.refs.Add(1) + go func() { + defer s.refs.Done() + _, err := s.addConn(c, network.DirInbound) + switch err { + case nil: + case ErrSwarmClosed: + // ignore. 
+ return + default: + log.Warnw("adding connection failed", "to", a, "error", err) + return + } + }() + } + }() + return nil +} diff --git a/p2p/net/swarm/swarm_net_test.go b/p2p/net/swarm/swarm_net_test.go new file mode 100644 index 0000000000..1fa08bb7ff --- /dev/null +++ b/p2p/net/swarm/swarm_net_test.go @@ -0,0 +1,164 @@ +package swarm_test + +import ( + "context" + "fmt" + "io/ioutil" + "testing" + "time" + + . "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + + "github.com/libp2p/go-libp2p-core/network" + + "github.com/stretchr/testify/require" +) + +// TestConnectednessCorrect starts a few networks, connects a few +// and tests Connectedness value is correct. +func TestConnectednessCorrect(t *testing.T) { + nets := make([]network.Network, 4) + for i := 0; i < 4; i++ { + nets[i] = GenSwarm(t) + } + + // connect 0-1, 0-2, 0-3, 1-2, 2-3 + + dial := func(a, b network.Network) { + DivulgeAddresses(b, a) + if _, err := a.DialPeer(context.Background(), b.LocalPeer()); err != nil { + t.Fatalf("Failed to dial: %s", err) + } + } + + dial(nets[0], nets[1]) + dial(nets[0], nets[3]) + dial(nets[1], nets[2]) + dial(nets[3], nets[2]) + + // The notifications for new connections get sent out asynchronously. + // There is the potential for a race condition here, so we sleep to ensure + // that they have been received. + time.Sleep(time.Millisecond * 100) + + // test those connected show up correctly + + // test connected + expectConnectedness(t, nets[0], nets[1], network.Connected) + expectConnectedness(t, nets[0], nets[3], network.Connected) + expectConnectedness(t, nets[1], nets[2], network.Connected) + expectConnectedness(t, nets[3], nets[2], network.Connected) + + // test not connected + expectConnectedness(t, nets[0], nets[2], network.NotConnected) + expectConnectedness(t, nets[1], nets[3], network.NotConnected) + + require.Len(t, nets[0].Peers(), 2, "expected net 0 to have two peers") + require.Len(t, nets[2].Peers(), 2, "expected net 2 to have two peers") + require.NotZerof(t, nets[1].ConnsToPeer(nets[3].LocalPeer()), "net 1 should have no connections to net 3") + require.NoError(t, nets[2].ClosePeer(nets[1].LocalPeer())) + + time.Sleep(time.Millisecond * 50) + expectConnectedness(t, nets[2], nets[1], network.NotConnected) + + for _, n := range nets { + n.Close() + } +} + +func expectConnectedness(t *testing.T, a, b network.Network, expected network.Connectedness) { + es := "%s is connected to %s, but Connectedness incorrect. 
%s %s %s" + atob := a.Connectedness(b.LocalPeer()) + btoa := b.Connectedness(a.LocalPeer()) + if atob != expected { + t.Errorf(es, a, b, printConns(a), printConns(b), atob) + } + + // test symmetric case + if btoa != expected { + t.Errorf(es, b, a, printConns(b), printConns(a), btoa) + } +} + +func printConns(n network.Network) string { + s := fmt.Sprintf("Connections in %s:\n", n) + for _, c := range n.Conns() { + s = s + fmt.Sprintf("- %s\n", c) + } + return s +} + +func TestNetworkOpenStream(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testString := "hello ipfs" + + nets := make([]network.Network, 4) + for i := 0; i < 4; i++ { + nets[i] = GenSwarm(t) + } + + dial := func(a, b network.Network) { + DivulgeAddresses(b, a) + if _, err := a.DialPeer(ctx, b.LocalPeer()); err != nil { + t.Fatalf("Failed to dial: %s", err) + } + } + + dial(nets[0], nets[1]) + dial(nets[0], nets[3]) + dial(nets[1], nets[2]) + + done := make(chan bool) + nets[1].SetStreamHandler(func(s network.Stream) { + defer close(done) + defer s.Close() + + buf, err := ioutil.ReadAll(s) + if err != nil { + t.Error(err) + return + } + if string(buf) != testString { + t.Error("got wrong message") + } + }) + + s, err := nets[0].NewStream(ctx, nets[1].LocalPeer()) + if err != nil { + t.Fatal(err) + } + + var numStreams int + for _, conn := range nets[0].ConnsToPeer(nets[1].LocalPeer()) { + numStreams += conn.Stat().NumStreams + } + + if numStreams != 1 { + t.Fatal("should only have one stream there") + } + + n, err := s.Write([]byte(testString)) + if err != nil { + t.Fatal(err) + } else if n != len(testString) { + t.Errorf("expected to write %d bytes, wrote %d", len(testString), n) + } + + err = s.Close() + if err != nil { + t.Fatal(err) + } + + select { + case <-done: + case <-time.After(time.Millisecond * 100): + t.Fatal("timed out waiting on stream") + } + + _, err = nets[1].NewStream(ctx, nets[3].LocalPeer()) + if err == nil { + t.Fatal("expected stream open 1->3 to fail") + } +} diff --git a/p2p/net/swarm/swarm_notif_test.go b/p2p/net/swarm/swarm_notif_test.go new file mode 100644 index 0000000000..8b1011bee7 --- /dev/null +++ b/p2p/net/swarm/swarm_notif_test.go @@ -0,0 +1,226 @@ +package swarm_test + +import ( + "context" + "testing" + "time" + + . "github.com/libp2p/go-libp2p/p2p/net/swarm" + + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + + ma "github.com/multiformats/go-multiaddr" + + "github.com/stretchr/testify/require" +) + +func TestNotifications(t *testing.T) { + const swarmSize = 5 + + notifiees := make([]*netNotifiee, swarmSize) + + swarms := makeSwarms(t, swarmSize) + defer func() { + for i, s := range swarms { + select { + case <-notifiees[i].listenClose: + t.Error("should not have been closed") + default: + } + require.NoError(t, s.Close()) + select { + case <-notifiees[i].listenClose: + default: + t.Error("expected a listen close notification") + } + } + }() + + const timeout = 5 * time.Second + + // signup notifs + for i, swarm := range swarms { + n := newNetNotifiee(swarmSize) + swarm.Notify(n) + notifiees[i] = n + } + + connectSwarms(t, context.Background(), swarms) + + time.Sleep(50 * time.Millisecond) + // should've gotten 5 by now. 
+ + // test everyone got the correct connection opened calls + for i, s := range swarms { + n := notifiees[i] + notifs := make(map[peer.ID][]network.Conn) + for j, s2 := range swarms { + if i == j { + continue + } + + // this feels a little sketchy, but its probably okay + for len(s.ConnsToPeer(s2.LocalPeer())) != len(notifs[s2.LocalPeer()]) { + select { + case c := <-n.connected: + nfp := notifs[c.RemotePeer()] + notifs[c.RemotePeer()] = append(nfp, c) + case <-time.After(timeout): + t.Fatal("timeout") + } + } + } + + for p, cons := range notifs { + expect := s.ConnsToPeer(p) + if len(expect) != len(cons) { + t.Fatal("got different number of connections") + } + + for _, c := range cons { + var found bool + for _, c2 := range expect { + if c == c2 { + found = true + break + } + } + + if !found { + t.Fatal("connection not found!") + } + } + } + } + + complement := func(c network.Conn) (*Swarm, *netNotifiee, *Conn) { + for i, s := range swarms { + for _, c2 := range s.Conns() { + if c.LocalMultiaddr().Equal(c2.RemoteMultiaddr()) && + c2.LocalMultiaddr().Equal(c.RemoteMultiaddr()) { + return s, notifiees[i], c2.(*Conn) + } + } + } + t.Fatal("complementary conn not found", c) + return nil, nil, nil + } + + testOCStream := func(n *netNotifiee, s network.Stream) { + var s2 network.Stream + select { + case s2 = <-n.openedStream: + t.Log("got notif for opened stream") + case <-time.After(timeout): + t.Fatal("timeout") + } + if s != s2 { + t.Fatal("got incorrect stream", s.Conn(), s2.Conn()) + } + + select { + case s2 = <-n.closedStream: + t.Log("got notif for closed stream") + case <-time.After(timeout): + t.Fatal("timeout") + } + if s != s2 { + t.Fatal("got incorrect stream", s.Conn(), s2.Conn()) + } + } + + streams := make(chan network.Stream) + for _, s := range swarms { + s.SetStreamHandler(func(s network.Stream) { + streams <- s + s.Reset() + }) + } + + // open a streams in each conn + for i, s := range swarms { + for _, c := range s.Conns() { + _, n2, _ := complement(c) + + st1, err := c.NewStream(context.Background()) + if err != nil { + t.Error(err) + } else { + st1.Write([]byte("hello")) + st1.Reset() + testOCStream(notifiees[i], st1) + st2 := <-streams + testOCStream(n2, st2) + } + } + } + + // close conns + for i, s := range swarms { + n := notifiees[i] + for _, c := range s.Conns() { + _, n2, c2 := complement(c) + c.Close() + c2.Close() + + var c3, c4 network.Conn + select { + case c3 = <-n.disconnected: + case <-time.After(timeout): + t.Fatal("timeout") + } + if c != c3 { + t.Fatal("got incorrect conn", c, c3) + } + + select { + case c4 = <-n2.disconnected: + case <-time.After(timeout): + t.Fatal("timeout") + } + if c2 != c4 { + t.Fatal("got incorrect conn", c, c2) + } + } + } +} + +type netNotifiee struct { + listen chan ma.Multiaddr + listenClose chan ma.Multiaddr + connected chan network.Conn + disconnected chan network.Conn + openedStream chan network.Stream + closedStream chan network.Stream +} + +func newNetNotifiee(buffer int) *netNotifiee { + return &netNotifiee{ + listen: make(chan ma.Multiaddr, buffer), + listenClose: make(chan ma.Multiaddr, buffer), + connected: make(chan network.Conn, buffer), + disconnected: make(chan network.Conn, buffer), + openedStream: make(chan network.Stream, buffer), + closedStream: make(chan network.Stream, buffer), + } +} + +func (nn *netNotifiee) Listen(n network.Network, a ma.Multiaddr) { + nn.listen <- a +} +func (nn *netNotifiee) ListenClose(n network.Network, a ma.Multiaddr) { + nn.listenClose <- a +} +func (nn *netNotifiee) Connected(n 
network.Network, v network.Conn) { + nn.connected <- v +} +func (nn *netNotifiee) Disconnected(n network.Network, v network.Conn) { + nn.disconnected <- v +} +func (nn *netNotifiee) OpenedStream(n network.Network, v network.Stream) { + nn.openedStream <- v +} +func (nn *netNotifiee) ClosedStream(n network.Network, v network.Stream) { + nn.closedStream <- v +} diff --git a/p2p/net/swarm/swarm_stream.go b/p2p/net/swarm/swarm_stream.go new file mode 100644 index 0000000000..5e5c965335 --- /dev/null +++ b/p2p/net/swarm/swarm_stream.go @@ -0,0 +1,165 @@ +package swarm + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/protocol" +) + +// Validate Stream conforms to the go-libp2p-net Stream interface +var _ network.Stream = &Stream{} + +// Stream is the stream type used by swarm. In general, you won't use this type +// directly. +type Stream struct { + id uint64 + + stream network.MuxedStream + conn *Conn + scope network.StreamManagementScope + + closeOnce sync.Once + + notifyLk sync.Mutex + + protocol atomic.Value + + stat network.Stats +} + +func (s *Stream) ID() string { + // format: -- + return fmt.Sprintf("%s-%d", s.conn.ID(), s.id) +} + +func (s *Stream) String() string { + return fmt.Sprintf( + " %s (%s)>", + s.conn.conn.Transport(), + s.conn.LocalMultiaddr(), + s.conn.LocalPeer(), + s.conn.RemoteMultiaddr(), + s.conn.RemotePeer(), + ) +} + +// Conn returns the Conn associated with this stream, as an network.Conn +func (s *Stream) Conn() network.Conn { + return s.conn +} + +// Read reads bytes from a stream. +func (s *Stream) Read(p []byte) (int, error) { + n, err := s.stream.Read(p) + // TODO: push this down to a lower level for better accuracy. + if s.conn.swarm.bwc != nil { + s.conn.swarm.bwc.LogRecvMessage(int64(n)) + s.conn.swarm.bwc.LogRecvMessageStream(int64(n), s.Protocol(), s.Conn().RemotePeer()) + } + return n, err +} + +// Write writes bytes to a stream, flushing for each call. +func (s *Stream) Write(p []byte) (int, error) { + n, err := s.stream.Write(p) + // TODO: push this down to a lower level for better accuracy. + if s.conn.swarm.bwc != nil { + s.conn.swarm.bwc.LogSentMessage(int64(n)) + s.conn.swarm.bwc.LogSentMessageStream(int64(n), s.Protocol(), s.Conn().RemotePeer()) + } + return n, err +} + +// Close closes the stream, closing both ends and freeing all associated +// resources. +func (s *Stream) Close() error { + err := s.stream.Close() + s.closeOnce.Do(s.remove) + return err +} + +// Reset resets the stream, signaling an error on both ends and freeing all +// associated resources. +func (s *Stream) Reset() error { + err := s.stream.Reset() + s.closeOnce.Do(s.remove) + return err +} + +// Close closes the stream for writing, flushing all data and sending an EOF. +// This function does not free resources, call Close or Reset when done with the +// stream. +func (s *Stream) CloseWrite() error { + return s.stream.CloseWrite() +} + +// Close closes the stream for reading. This function does not free resources, +// call Close or Reset when done with the stream. +func (s *Stream) CloseRead() error { + return s.stream.CloseRead() +} + +func (s *Stream) remove() { + s.conn.removeStream(s) + + // We *must* do this in a goroutine. This can be called during a + // an open notification and will block until that notification is done. 
+ go func() { + s.notifyLk.Lock() + defer s.notifyLk.Unlock() + + s.conn.swarm.notifyAll(func(f network.Notifiee) { + f.ClosedStream(s.conn.swarm, s) + }) + s.conn.swarm.refs.Done() + }() +} + +// Protocol returns the protocol negotiated on this stream (if set). +func (s *Stream) Protocol() protocol.ID { + // Ignore type error. It means that the protocol is unset. + p, _ := s.protocol.Load().(protocol.ID) + return p +} + +// SetProtocol sets the protocol for this stream. +// +// This doesn't actually *do* anything other than record the fact that we're +// speaking the given protocol over this stream. It's still up to the user to +// negotiate the protocol. This is usually done by the Host. +func (s *Stream) SetProtocol(p protocol.ID) error { + if err := s.scope.SetProtocol(p); err != nil { + return err + } + + s.protocol.Store(p) + return nil +} + +// SetDeadline sets the read and write deadlines for this stream. +func (s *Stream) SetDeadline(t time.Time) error { + return s.stream.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline for this stream. +func (s *Stream) SetReadDeadline(t time.Time) error { + return s.stream.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline for this stream. +func (s *Stream) SetWriteDeadline(t time.Time) error { + return s.stream.SetWriteDeadline(t) +} + +// Stat returns metadata information for this stream. +func (s *Stream) Stat() network.Stats { + return s.stat +} + +func (s *Stream) Scope() network.StreamScope { + return s.scope +} diff --git a/p2p/net/swarm/swarm_test.go b/p2p/net/swarm/swarm_test.go new file mode 100644 index 0000000000..c7db03e994 --- /dev/null +++ b/p2p/net/swarm/swarm_test.go @@ -0,0 +1,541 @@ +package swarm_test + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "strings" + "sync" + "testing" + "time" + + "github.com/libp2p/go-libp2p-core/protocol" + + "github.com/libp2p/go-libp2p/p2p/net/swarm" + . "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + + "github.com/libp2p/go-libp2p-core/control" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/peerstore" + + logging "github.com/ipfs/go-log/v2" + mocknetwork "github.com/libp2p/go-libp2p-testing/mocks/network" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" +) + +var log = logging.Logger("swarm_test") + +func EchoStreamHandler(stream network.Stream) { + go func() { + defer stream.Close() + + // pull out the ipfs conn + c := stream.Conn() + log.Infof("%s ponging to %s", c.LocalPeer(), c.RemotePeer()) + + buf := make([]byte, 4) + + for { + if _, err := stream.Read(buf); err != nil { + if err != io.EOF { + log.Error("ping receive error:", err) + } + return + } + + if !bytes.Equal(buf, []byte("ping")) { + log.Errorf("ping receive error: ping != %s %v", buf, buf) + return + } + + if _, err := stream.Write([]byte("pong")); err != nil { + log.Error("pond send error:", err) + return + } + } + }() +} + +func makeDialOnlySwarm(t *testing.T) *swarm.Swarm { + swarm := GenSwarm(t, OptDialOnly) + swarm.SetStreamHandler(EchoStreamHandler) + return swarm +} + +func makeSwarms(t *testing.T, num int, opts ...Option) []*swarm.Swarm { + swarms := make([]*swarm.Swarm, 0, num) + for i := 0; i < num; i++ { + swarm := GenSwarm(t, opts...) 
+ swarm.SetStreamHandler(EchoStreamHandler) + swarms = append(swarms, swarm) + } + return swarms +} + +func connectSwarms(t *testing.T, ctx context.Context, swarms []*swarm.Swarm) { + var wg sync.WaitGroup + connect := func(s *swarm.Swarm, dst peer.ID, addr ma.Multiaddr) { + // TODO: make a DialAddr func. + s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) + if _, err := s.DialPeer(ctx, dst); err != nil { + t.Fatal("error swarm dialing to peer", err) + } + wg.Done() + } + + log.Info("Connecting swarms simultaneously.") + for i, s1 := range swarms { + for _, s2 := range swarms[i+1:] { + wg.Add(1) + connect(s1, s2.LocalPeer(), s2.ListenAddresses()[0]) // try the first. + } + } + wg.Wait() + + for _, s := range swarms { + log.Infof("%s swarm routing table: %s", s.LocalPeer(), s.Peers()) + } +} + +func subtestSwarm(t *testing.T, SwarmNum int, MsgNum int) { + swarms := makeSwarms(t, SwarmNum, OptDisableReuseport) + + // connect everyone + connectSwarms(t, context.Background(), swarms) + + // ping/pong + for _, s1 := range swarms { + log.Debugf("-------------------------------------------------------") + log.Debugf("%s ping pong round", s1.LocalPeer()) + log.Debugf("-------------------------------------------------------") + + _, cancel := context.WithCancel(context.Background()) + got := map[peer.ID]int{} + errChan := make(chan error, MsgNum*len(swarms)) + streamChan := make(chan network.Stream, MsgNum) + + // send out "ping" x MsgNum to every peer + go func() { + defer close(streamChan) + + var wg sync.WaitGroup + send := func(p peer.ID) { + defer wg.Done() + + // first, one stream per peer (nice) + stream, err := s1.NewStream(context.Background(), p) + if err != nil { + errChan <- err + return + } + + // send out ping! + for k := 0; k < MsgNum; k++ { // with k messages + msg := "ping" + log.Debugf("%s %s %s (%d)", s1.LocalPeer(), msg, p, k) + if _, err := stream.Write([]byte(msg)); err != nil { + errChan <- err + continue + } + } + + // read it later + streamChan <- stream + } + + for _, s2 := range swarms { + if s2.LocalPeer() == s1.LocalPeer() { + continue // dont send to self... 
+ } + + wg.Add(1) + go send(s2.LocalPeer()) + } + wg.Wait() + }() + + // receive "pong" x MsgNum from every peer + go func() { + defer close(errChan) + count := 0 + countShouldBe := MsgNum * (len(swarms) - 1) + for stream := range streamChan { // one per peer + // get peer on the other side + p := stream.Conn().RemotePeer() + + // receive pings + msgCount := 0 + msg := make([]byte, 4) + for k := 0; k < MsgNum; k++ { // with k messages + + // read from the stream + if _, err := stream.Read(msg); err != nil { + errChan <- err + continue + } + + if string(msg) != "pong" { + errChan <- fmt.Errorf("unexpected message: %s", msg) + continue + } + + log.Debugf("%s %s %s (%d)", s1.LocalPeer(), msg, p, k) + msgCount++ + } + + got[p] = msgCount + count += msgCount + stream.Close() + } + + if count != countShouldBe { + errChan <- fmt.Errorf("count mismatch: %d != %d", count, countShouldBe) + } + }() + + // check any errors (blocks till consumer is done) + for err := range errChan { + if err != nil { + t.Error(err.Error()) + } + } + + log.Debugf("%s got pongs", s1.LocalPeer()) + if (len(swarms) - 1) != len(got) { + t.Errorf("got (%d) less messages than sent (%d).", len(got), len(swarms)) + } + + for p, n := range got { + if n != MsgNum { + t.Error("peer did not get all msgs", p, n, "/", MsgNum) + } + } + + cancel() + <-time.After(10 * time.Millisecond) + } +} + +func TestSwarm(t *testing.T) { + t.Parallel() + subtestSwarm(t, 5, 100) +} + +func TestBasicSwarm(t *testing.T) { + // t.Skip("skipping for another test") + t.Parallel() + subtestSwarm(t, 2, 1) +} + +func TestConnectionGating(t *testing.T) { + ctx := context.Background() + tcs := map[string]struct { + p1Gater func(gater *MockConnectionGater) *MockConnectionGater + p2Gater func(gater *MockConnectionGater) *MockConnectionGater + + p1ConnectednessToP2 network.Connectedness + p2ConnectednessToP1 network.Connectedness + isP1OutboundErr bool + disableOnQUIC bool + }{ + "no gating": { + p1ConnectednessToP2: network.Connected, + p2ConnectednessToP1: network.Connected, + isP1OutboundErr: false, + }, + "p1 gates outbound peer dial": { + p1Gater: func(c *MockConnectionGater) *MockConnectionGater { + c.PeerDial = func(p peer.ID) bool { return false } + return c + }, + p1ConnectednessToP2: network.NotConnected, + p2ConnectednessToP1: network.NotConnected, + isP1OutboundErr: true, + }, + "p1 gates outbound addr dialing": { + p1Gater: func(c *MockConnectionGater) *MockConnectionGater { + c.Dial = func(p peer.ID, addr ma.Multiaddr) bool { return false } + return c + }, + p1ConnectednessToP2: network.NotConnected, + p2ConnectednessToP1: network.NotConnected, + isP1OutboundErr: true, + }, + "p2 accepts inbound peer dial if outgoing dial is gated": { + p2Gater: func(c *MockConnectionGater) *MockConnectionGater { + c.Dial = func(peer.ID, ma.Multiaddr) bool { return false } + return c + }, + p1ConnectednessToP2: network.Connected, + p2ConnectednessToP1: network.Connected, + isP1OutboundErr: false, + }, + "p2 gates inbound peer dial before securing": { + p2Gater: func(c *MockConnectionGater) *MockConnectionGater { + c.Accept = func(c network.ConnMultiaddrs) bool { return false } + return c + }, + p1ConnectednessToP2: network.NotConnected, + p2ConnectednessToP1: network.NotConnected, + isP1OutboundErr: true, + // QUIC gates the connection after completion of the handshake + disableOnQUIC: true, + }, + "p2 gates inbound peer dial before multiplexing": { + p1Gater: func(c *MockConnectionGater) *MockConnectionGater { + c.Secured = func(network.Direction, peer.ID, 
network.ConnMultiaddrs) bool { return false } + return c + }, + p1ConnectednessToP2: network.NotConnected, + p2ConnectednessToP1: network.NotConnected, + isP1OutboundErr: true, + }, + "p2 gates inbound peer dial after upgrading": { + p1Gater: func(c *MockConnectionGater) *MockConnectionGater { + c.Upgraded = func(c network.Conn) (bool, control.DisconnectReason) { return false, 0 } + return c + }, + p1ConnectednessToP2: network.NotConnected, + p2ConnectednessToP1: network.NotConnected, + isP1OutboundErr: true, + }, + "p2 gates outbound dials": { + p2Gater: func(c *MockConnectionGater) *MockConnectionGater { + c.PeerDial = func(p peer.ID) bool { return false } + return c + }, + p1ConnectednessToP2: network.Connected, + p2ConnectednessToP1: network.Connected, + isP1OutboundErr: false, + }, + } + + for n, tc := range tcs { + for _, useQuic := range []bool{false, true} { + trString := "TCP" + optTransport := OptDisableQUIC + if useQuic { + if tc.disableOnQUIC { + continue + } + trString = "QUIC" + optTransport = OptDisableTCP + } + t.Run(fmt.Sprintf("%s %s", n, trString), func(t *testing.T) { + p1Gater := DefaultMockConnectionGater() + p2Gater := DefaultMockConnectionGater() + if tc.p1Gater != nil { + p1Gater = tc.p1Gater(p1Gater) + } + if tc.p2Gater != nil { + p2Gater = tc.p2Gater(p2Gater) + } + + sw1 := GenSwarm(t, OptConnGater(p1Gater), optTransport) + sw2 := GenSwarm(t, OptConnGater(p2Gater), optTransport) + + p1 := sw1.LocalPeer() + p2 := sw2.LocalPeer() + sw1.Peerstore().AddAddr(p2, sw2.ListenAddresses()[0], peerstore.PermanentAddrTTL) + // 1 -> 2 + _, err := sw1.DialPeer(ctx, p2) + + require.Equal(t, tc.isP1OutboundErr, err != nil, n) + require.Equal(t, tc.p1ConnectednessToP2, sw1.Connectedness(p2), n) + + require.Eventually(t, func() bool { + return tc.p2ConnectednessToP1 == sw2.Connectedness(p1) + }, 2*time.Second, 100*time.Millisecond, n) + }) + } + } +} + +func TestNoDial(t *testing.T) { + swarms := makeSwarms(t, 2) + + _, err := swarms[0].NewStream(network.WithNoDial(context.Background(), "swarm test"), swarms[1].LocalPeer()) + if err != network.ErrNoConn { + t.Fatal("should have failed with ErrNoConn") + } +} + +func TestCloseWithOpenStreams(t *testing.T) { + ctx := context.Background() + swarms := makeSwarms(t, 2) + connectSwarms(t, ctx, swarms) + + s, err := swarms[0].NewStream(ctx, swarms[1].LocalPeer()) + require.NoError(t, err) + defer s.Close() + // close swarm before stream. + require.NoError(t, swarms[0].Close()) +} + +func TestTypedNilConn(t *testing.T) { + s := GenSwarm(t) + defer s.Close() + + // We can't dial ourselves. + c, err := s.DialPeer(context.Background(), s.LocalPeer()) + require.Error(t, err) + // If we fail to dial, the connection should be nil. 
+ require.Nil(t, c) +} + +func TestPreventDialListenAddr(t *testing.T) { + s := GenSwarm(t, OptDialOnly) + if err := s.Listen(ma.StringCast("/ip4/0.0.0.0/udp/0/quic")); err != nil { + t.Fatal(err) + } + addrs, err := s.InterfaceListenAddresses() + if err != nil { + t.Fatal(err) + } + var addr ma.Multiaddr + for _, a := range addrs { + _, s, err := manet.DialArgs(a) + if err != nil { + t.Fatal(err) + } + if strings.Split(s, ":")[0] == "127.0.0.1" { + addr = a + break + } + } + remote := peer.ID("foobar") + s.Peerstore().AddAddr(remote, addr, time.Hour) + _, err = s.DialPeer(context.Background(), remote) + if !errors.Is(err, swarm.ErrNoGoodAddresses) { + t.Fatal("expected dial to fail: %w", err) + } +} + +func TestStreamCount(t *testing.T) { + s1 := GenSwarm(t) + s2 := GenSwarm(t) + connectSwarms(t, context.Background(), []*swarm.Swarm{s2, s1}) + + countStreams := func() (n int) { + var num int + for _, c := range s1.ConnsToPeer(s2.LocalPeer()) { + n += c.Stat().NumStreams + num += len(c.GetStreams()) + } + require.Equal(t, n, num, "inconsistent stream count") + return + } + + streams := make(chan network.Stream, 20) + streamAccepted := make(chan struct{}, 1) + s1.SetStreamHandler(func(str network.Stream) { + streams <- str + streamAccepted <- struct{}{} + }) + + for i := 0; i < 10; i++ { + str, err := s2.NewStream(context.Background(), s1.LocalPeer()) + require.NoError(t, err) + str.Write([]byte("foobar")) + <-streamAccepted + } + require.Eventually(t, func() bool { return len(streams) == 10 }, 5*time.Second, 10*time.Millisecond) + require.Equal(t, countStreams(), 10) + (<-streams).Reset() + (<-streams).Close() + require.Equal(t, countStreams(), 8) + + str, err := s1.NewStream(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + require.Equal(t, countStreams(), 9) + str.Close() + require.Equal(t, countStreams(), 8) +} + +func TestResourceManager(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + rcmgr1 := mocknetwork.NewMockResourceManager(ctrl) + s1 := GenSwarm(t, OptResourceManager(rcmgr1)) + defer s1.Close() + + rcmgr2 := mocknetwork.NewMockResourceManager(ctrl) + s2 := GenSwarm(t, OptResourceManager(rcmgr2)) + defer s2.Close() + connectSwarms(t, context.Background(), []*swarm.Swarm{s1, s2}) + + strChan := make(chan network.Stream) + s2.SetStreamHandler(func(str network.Stream) { strChan <- str }) + + streamScope1 := mocknetwork.NewMockStreamManagementScope(ctrl) + rcmgr1.EXPECT().OpenStream(s2.LocalPeer(), network.DirOutbound).Return(streamScope1, nil) + streamScope2 := mocknetwork.NewMockStreamManagementScope(ctrl) + rcmgr2.EXPECT().OpenStream(s1.LocalPeer(), network.DirInbound).Return(streamScope2, nil) + str, err := s1.NewStream(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + str.Write([]byte("foobar")) + + p := protocol.ID("proto") + streamScope1.EXPECT().SetProtocol(p) + require.NoError(t, str.SetProtocol(p)) + + sstr := <-strChan + streamScope2.EXPECT().Done() + require.NoError(t, sstr.Close()) + streamScope1.EXPECT().Done() + require.NoError(t, str.Close()) +} + +func TestResourceManagerNewStream(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + rcmgr1 := mocknetwork.NewMockResourceManager(ctrl) + s1 := GenSwarm(t, OptResourceManager(rcmgr1)) + defer s1.Close() + + s2 := GenSwarm(t) + defer s2.Close() + + connectSwarms(t, context.Background(), []*swarm.Swarm{s1, s2}) + + rerr := errors.New("denied") + rcmgr1.EXPECT().OpenStream(s2.LocalPeer(), network.DirOutbound).Return(nil, rerr) + _, err := 
s1.NewStream(context.Background(), s2.LocalPeer()) + require.ErrorIs(t, err, rerr) +} + +func TestResourceManagerAcceptStream(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + rcmgr1 := mocknetwork.NewMockResourceManager(ctrl) + s1 := GenSwarm(t, OptResourceManager(rcmgr1)) + defer s1.Close() + + rcmgr2 := mocknetwork.NewMockResourceManager(ctrl) + s2 := GenSwarm(t, OptResourceManager(rcmgr2)) + defer s2.Close() + s2.SetStreamHandler(func(str network.Stream) { t.Fatal("didn't expect to accept a stream") }) + + connectSwarms(t, context.Background(), []*swarm.Swarm{s1, s2}) + + streamScope := mocknetwork.NewMockStreamManagementScope(ctrl) + rcmgr1.EXPECT().OpenStream(s2.LocalPeer(), network.DirOutbound).Return(streamScope, nil) + streamScope.EXPECT().Done() + rcmgr2.EXPECT().OpenStream(s1.LocalPeer(), network.DirInbound).Return(nil, errors.New("nope")) + str, err := s1.NewStream(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + _, err = str.Write([]byte("foobar")) + require.NoError(t, err) + _, err = str.Read([]byte{0}) + require.EqualError(t, err, "stream reset") +} diff --git a/p2p/net/swarm/swarm_transport.go b/p2p/net/swarm/swarm_transport.go new file mode 100644 index 0000000000..21728ac3b5 --- /dev/null +++ b/p2p/net/swarm/swarm_transport.go @@ -0,0 +1,111 @@ +package swarm + +import ( + "fmt" + "strings" + + "github.com/libp2p/go-libp2p-core/transport" + + ma "github.com/multiformats/go-multiaddr" +) + +// TransportForDialing retrieves the appropriate transport for dialing the given +// multiaddr. +func (s *Swarm) TransportForDialing(a ma.Multiaddr) transport.Transport { + protocols := a.Protocols() + if len(protocols) == 0 { + return nil + } + + s.transports.RLock() + defer s.transports.RUnlock() + if len(s.transports.m) == 0 { + // make sure we're not just shutting down. + if s.transports.m != nil { + log.Error("you have no transports configured") + } + return nil + } + + for _, p := range protocols { + transport, ok := s.transports.m[p.Code] + if !ok { + continue + } + if transport.Proxy() { + return transport + } + } + + return s.transports.m[protocols[len(protocols)-1].Code] +} + +// TransportForListening retrieves the appropriate transport for listening on +// the given multiaddr. +func (s *Swarm) TransportForListening(a ma.Multiaddr) transport.Transport { + protocols := a.Protocols() + if len(protocols) == 0 { + return nil + } + + s.transports.RLock() + defer s.transports.RUnlock() + if len(s.transports.m) == 0 { + // make sure we're not just shutting down. + if s.transports.m != nil { + log.Error("you have no transports configured") + } + return nil + } + + selected := s.transports.m[protocols[len(protocols)-1].Code] + for _, p := range protocols { + transport, ok := s.transports.m[p.Code] + if !ok { + continue + } + if transport.Proxy() { + selected = transport + } + } + return selected +} + +// AddTransport adds a transport to this swarm. +// +// Satisfies the Network interface from go-libp2p-transport. 
+func (s *Swarm) AddTransport(t transport.Transport) error { + protocols := t.Protocols() + + if len(protocols) == 0 { + return fmt.Errorf("useless transport handles no protocols: %T", t) + } + + s.transports.Lock() + defer s.transports.Unlock() + if s.transports.m == nil { + return ErrSwarmClosed + } + var registered []string + for _, p := range protocols { + if _, ok := s.transports.m[p]; ok { + proto := ma.ProtocolWithCode(p) + name := proto.Name + if name == "" { + name = fmt.Sprintf("unknown (%d)", p) + } + registered = append(registered, name) + } + } + if len(registered) > 0 { + return fmt.Errorf( + "transports already registered for protocol(s): %s", + strings.Join(registered, ", "), + ) + } + + for _, p := range protocols { + s.transports.m[p] = t + } + return nil +} diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go new file mode 100644 index 0000000000..30d8f22055 --- /dev/null +++ b/p2p/net/swarm/testing/testing.go @@ -0,0 +1,243 @@ +package testing + +import ( + "testing" + "time" + + "github.com/libp2p/go-libp2p/p2p/net/swarm" + + "github.com/libp2p/go-libp2p-core/connmgr" + "github.com/libp2p/go-libp2p-core/control" + "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-core/metrics" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/peerstore" + "github.com/libp2p/go-libp2p-core/sec/insecure" + "github.com/libp2p/go-libp2p-core/transport" + "github.com/libp2p/go-tcp-transport" + + csms "github.com/libp2p/go-conn-security-multistream" + "github.com/libp2p/go-libp2p-peerstore/pstoremem" + quic "github.com/libp2p/go-libp2p-quic-transport" + tnet "github.com/libp2p/go-libp2p-testing/net" + tptu "github.com/libp2p/go-libp2p-transport-upgrader" + yamux "github.com/libp2p/go-libp2p-yamux" + msmux "github.com/libp2p/go-stream-muxer-multistream" + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +type config struct { + disableReuseport bool + dialOnly bool + disableTCP bool + disableQUIC bool + dialTimeout time.Duration + connectionGater connmgr.ConnectionGater + rcmgr network.ResourceManager + sk crypto.PrivKey +} + +// Option is an option that can be passed when constructing a test swarm. +type Option func(*testing.T, *config) + +// OptDisableReuseport disables reuseport in this test swarm. +var OptDisableReuseport Option = func(_ *testing.T, c *config) { + c.disableReuseport = true +} + +// OptDialOnly prevents the test swarm from listening. +var OptDialOnly Option = func(_ *testing.T, c *config) { + c.dialOnly = true +} + +// OptDisableTCP disables TCP. +var OptDisableTCP Option = func(_ *testing.T, c *config) { + c.disableTCP = true +} + +// OptDisableQUIC disables QUIC. +var OptDisableQUIC Option = func(_ *testing.T, c *config) { + c.disableQUIC = true +} + +// OptConnGater configures the given connection gater on the test +func OptConnGater(cg connmgr.ConnectionGater) Option { + return func(_ *testing.T, c *config) { + c.connectionGater = cg + } +} + +func OptResourceManager(rcmgr network.ResourceManager) Option { + return func(_ *testing.T, c *config) { + c.rcmgr = rcmgr + } +} + +// OptPeerPrivateKey configures the peer private key which is then used to derive the public key and peer ID. 
+func OptPeerPrivateKey(sk crypto.PrivKey) Option { + return func(_ *testing.T, c *config) { + c.sk = sk + } +} + +func DialTimeout(t time.Duration) Option { + return func(_ *testing.T, c *config) { + c.dialTimeout = t + } +} + +// GenUpgrader creates a new connection upgrader for use with this swarm. +func GenUpgrader(t *testing.T, n *swarm.Swarm, opts ...tptu.Option) transport.Upgrader { + id := n.LocalPeer() + pk := n.Peerstore().PrivKey(id) + secMuxer := new(csms.SSMuxer) + secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(id, pk)) + + stMuxer := msmux.NewBlankTransport() + stMuxer.AddTransport("/yamux/1.0.0", yamux.DefaultTransport) + u, err := tptu.New(secMuxer, stMuxer, opts...) + require.NoError(t, err) + return u +} + +// GenSwarm generates a new test swarm. +func GenSwarm(t *testing.T, opts ...Option) *swarm.Swarm { + var cfg config + for _, o := range opts { + o(t, &cfg) + } + + var p tnet.PeerNetParams + if cfg.sk == nil { + p = tnet.RandPeerNetParamsOrFatal(t) + } else { + pk := cfg.sk.GetPublic() + id, err := peer.IDFromPublicKey(pk) + if err != nil { + t.Fatal(err) + } + p.PrivKey = cfg.sk + p.PubKey = pk + p.ID = id + p.Addr = tnet.ZeroLocalTCPAddress + } + + ps, err := pstoremem.NewPeerstore() + require.NoError(t, err) + ps.AddPubKey(p.ID, p.PubKey) + ps.AddPrivKey(p.ID, p.PrivKey) + t.Cleanup(func() { ps.Close() }) + + swarmOpts := []swarm.Option{swarm.WithMetrics(metrics.NewBandwidthCounter())} + if cfg.connectionGater != nil { + swarmOpts = append(swarmOpts, swarm.WithConnectionGater(cfg.connectionGater)) + } + if cfg.rcmgr != nil { + swarmOpts = append(swarmOpts, swarm.WithResourceManager(cfg.rcmgr)) + } + if cfg.dialTimeout != 0 { + swarmOpts = append(swarmOpts, swarm.WithDialTimeout(cfg.dialTimeout)) + } + s, err := swarm.NewSwarm(p.ID, ps, swarmOpts...) + require.NoError(t, err) + + upgrader := GenUpgrader(t, s, tptu.WithConnectionGater(cfg.connectionGater)) + + if !cfg.disableTCP { + var tcpOpts []tcp.Option + if cfg.disableReuseport { + tcpOpts = append(tcpOpts, tcp.DisableReuseport()) + } + tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...) + require.NoError(t, err) + if err := s.AddTransport(tcpTransport); err != nil { + t.Fatal(err) + } + if !cfg.dialOnly { + if err := s.Listen(p.Addr); err != nil { + t.Fatal(err) + } + } + } + if !cfg.disableQUIC { + quicTransport, err := quic.NewTransport(p.PrivKey, nil, cfg.connectionGater, nil) + if err != nil { + t.Fatal(err) + } + if err := s.AddTransport(quicTransport); err != nil { + t.Fatal(err) + } + if !cfg.dialOnly { + if err := s.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic")); err != nil { + t.Fatal(err) + } + } + } + if !cfg.dialOnly { + s.Peerstore().AddAddrs(p.ID, s.ListenAddresses(), peerstore.PermanentAddrTTL) + } + return s +} + +// DivulgeAddresses adds swarm a's addresses to swarm b's peerstore. +func DivulgeAddresses(a, b network.Network) { + id := a.LocalPeer() + addrs := a.Peerstore().Addrs(id) + b.Peerstore().AddAddrs(id, addrs, peerstore.PermanentAddrTTL) +} + +// MockConnectionGater is a mock connection gater to be used by the tests. 
+type MockConnectionGater struct { + Dial func(p peer.ID, addr ma.Multiaddr) bool + PeerDial func(p peer.ID) bool + Accept func(c network.ConnMultiaddrs) bool + Secured func(network.Direction, peer.ID, network.ConnMultiaddrs) bool + Upgraded func(c network.Conn) (bool, control.DisconnectReason) +} + +func DefaultMockConnectionGater() *MockConnectionGater { + m := &MockConnectionGater{} + m.Dial = func(p peer.ID, addr ma.Multiaddr) bool { + return true + } + + m.PeerDial = func(p peer.ID) bool { + return true + } + + m.Accept = func(c network.ConnMultiaddrs) bool { + return true + } + + m.Secured = func(network.Direction, peer.ID, network.ConnMultiaddrs) bool { + return true + } + + m.Upgraded = func(c network.Conn) (bool, control.DisconnectReason) { + return true, 0 + } + + return m +} + +func (m *MockConnectionGater) InterceptAddrDial(p peer.ID, addr ma.Multiaddr) (allow bool) { + return m.Dial(p, addr) +} + +func (m *MockConnectionGater) InterceptPeerDial(p peer.ID) (allow bool) { + return m.PeerDial(p) +} + +func (m *MockConnectionGater) InterceptAccept(c network.ConnMultiaddrs) (allow bool) { + return m.Accept(c) +} + +func (m *MockConnectionGater) InterceptSecured(d network.Direction, p peer.ID, c network.ConnMultiaddrs) (allow bool) { + return m.Secured(d, p, c) +} + +func (m *MockConnectionGater) InterceptUpgraded(tc network.Conn) (allow bool, reason control.DisconnectReason) { + return m.Upgraded(tc) +} diff --git a/p2p/net/swarm/testing/testing_test.go b/p2p/net/swarm/testing/testing_test.go new file mode 100644 index 0000000000..ef62570224 --- /dev/null +++ b/p2p/net/swarm/testing/testing_test.go @@ -0,0 +1,13 @@ +package testing + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGenSwarm(t *testing.T) { + swarm := GenSwarm(t) + require.NoError(t, swarm.Close()) + GenUpgrader(t, swarm) +} diff --git a/p2p/net/swarm/transport_test.go b/p2p/net/swarm/transport_test.go new file mode 100644 index 0000000000..3c863b23e7 --- /dev/null +++ b/p2p/net/swarm/transport_test.go @@ -0,0 +1,71 @@ +package swarm_test + +import ( + "context" + "testing" + + "github.com/libp2p/go-libp2p/p2p/net/swarm" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/transport" + + ma "github.com/multiformats/go-multiaddr" + + "github.com/stretchr/testify/require" +) + +type dummyTransport struct { + protocols []int + proxy bool + closed bool +} + +func (dt *dummyTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (transport.CapableConn, error) { + panic("unimplemented") +} + +func (dt *dummyTransport) CanDial(addr ma.Multiaddr) bool { + panic("unimplemented") +} + +func (dt *dummyTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) { + panic("unimplemented") +} + +func (dt *dummyTransport) Proxy() bool { + return dt.proxy +} + +func (dt *dummyTransport) Protocols() []int { + return dt.protocols +} +func (dt *dummyTransport) Close() error { + dt.closed = true + return nil +} + +func TestUselessTransport(t *testing.T) { + s := swarmt.GenSwarm(t) + require.Error(t, s.AddTransport(new(dummyTransport)), "adding a transport that supports no protocols should have failed") +} + +func TestTransportClose(t *testing.T) { + s := swarmt.GenSwarm(t) + tpt := &dummyTransport{protocols: []int{1}} + require.NoError(t, s.AddTransport(tpt)) + _ = s.Close() + if !tpt.closed { + t.Fatal("expected transport to be closed") + } +} + +func TestTransportAfterClose(t 
*testing.T) { + s := swarmt.GenSwarm(t) + s.Close() + + tpt := &dummyTransport{protocols: []int{1}} + if err := s.AddTransport(tpt); err != swarm.ErrSwarmClosed { + t.Fatal("expected swarm closed error, got: ", err) + } +} diff --git a/p2p/net/swarm/util_test.go b/p2p/net/swarm/util_test.go new file mode 100644 index 0000000000..11124adb27 --- /dev/null +++ b/p2p/net/swarm/util_test.go @@ -0,0 +1,53 @@ +package swarm + +import ( + "fmt" + "testing" + + "github.com/libp2p/go-libp2p-core/test" + ma "github.com/multiformats/go-multiaddr" + + "github.com/stretchr/testify/require" +) + +func TestIsFdConsuming(t *testing.T) { + tcs := map[string]struct { + addr string + isFdConsuming bool + }{ + "tcp": { + addr: "/ip4/127.0.0.1/tcp/20", + isFdConsuming: true, + }, + "quic": { + addr: "/ip4/127.0.0.1/udp/0/quic", + isFdConsuming: false, + }, + "addr-without-registered-transport": { + addr: "/ip4/127.0.0.1/tcp/20/ws", + isFdConsuming: true, + }, + "relay-tcp": { + addr: fmt.Sprintf("/ip4/127.0.0.1/tcp/20/p2p-circuit/p2p/%s", test.RandPeerIDFatal(t)), + isFdConsuming: true, + }, + "relay-quic": { + addr: fmt.Sprintf("/ip4/127.0.0.1/udp/20/quic/p2p-circuit/p2p/%s", test.RandPeerIDFatal(t)), + isFdConsuming: false, + }, + "relay-without-serveraddr": { + addr: fmt.Sprintf("/p2p-circuit/p2p/%s", test.RandPeerIDFatal(t)), + isFdConsuming: true, + }, + "relay-without-registered-transport-server": { + addr: fmt.Sprintf("/ip4/127.0.0.1/tcp/20/ws/p2p-circuit/p2p/%s", test.RandPeerIDFatal(t)), + isFdConsuming: true, + }, + } + + for name := range tcs { + maddr, err := ma.NewMultiaddr(tcs[name].addr) + require.NoError(t, err, name) + require.Equal(t, tcs[name].isFdConsuming, isFdConsumingAddr(maddr), name) + } +} diff --git a/p2p/protocol/circuitv2/relay/relay_test.go b/p2p/protocol/circuitv2/relay/relay_test.go index f7b0f35e1d..5bd840406c 100644 --- a/p2p/protocol/circuitv2/relay/relay_test.go +++ b/p2p/protocol/circuitv2/relay/relay_test.go @@ -10,19 +10,19 @@ import ( "time" bhost "github.com/libp2p/go-libp2p/p2p/host/blank" + "github.com/libp2p/go-libp2p/p2p/net/swarm" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/metrics" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/transport" - "github.com/libp2p/go-libp2p-core/metrics" "github.com/libp2p/go-libp2p-peerstore/pstoremem" - swarm "github.com/libp2p/go-libp2p-swarm" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" "github.com/libp2p/go-tcp-transport" ma "github.com/multiformats/go-multiaddr" ) diff --git a/p2p/protocol/identify/id_glass_test.go b/p2p/protocol/identify/id_glass_test.go index 7111f70ba6..de83d8be6b 100644 --- a/p2p/protocol/identify/id_glass_test.go +++ b/p2p/protocol/identify/id_glass_test.go @@ -6,12 +6,11 @@ import ( "time" blhost "github.com/libp2p/go-libp2p/p2p/host/blank" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/p2p/protocol/identify/id_test.go b/p2p/protocol/identify/id_test.go index 
0d946c2f8e..10f274374f 100644 --- a/p2p/protocol/identify/id_test.go +++ b/p2p/protocol/identify/id_test.go @@ -12,6 +12,7 @@ import ( "github.com/libp2p/go-libp2p" blhost "github.com/libp2p/go-libp2p/p2p/host/blank" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p/p2p/protocol/identify" pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb" @@ -27,7 +28,6 @@ import ( "github.com/libp2p/go-eventbus" "github.com/libp2p/go-libp2p-peerstore/pstoremem" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" "github.com/libp2p/go-libp2p-testing/race" "github.com/libp2p/go-msgio/protoio" diff --git a/p2p/protocol/identify/peer_loop_test.go b/p2p/protocol/identify/peer_loop_test.go index 560495f21d..c6bbbd3fc4 100644 --- a/p2p/protocol/identify/peer_loop_test.go +++ b/p2p/protocol/identify/peer_loop_test.go @@ -6,12 +6,11 @@ import ( "time" blhost "github.com/libp2p/go-libp2p/p2p/host/blank" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" - "github.com/stretchr/testify/require" ) diff --git a/p2p/protocol/ping/ping_test.go b/p2p/protocol/ping/ping_test.go index e7b5964780..75021b1406 100644 --- a/p2p/protocol/ping/ping_test.go +++ b/p2p/protocol/ping/ping_test.go @@ -7,10 +7,11 @@ import ( "github.com/stretchr/testify/require" - "github.com/libp2p/go-libp2p-core/peer" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" bhost "github.com/libp2p/go-libp2p/p2p/host/basic" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p/p2p/protocol/ping" + + "github.com/libp2p/go-libp2p-core/peer" ) func TestPing(t *testing.T) { diff --git a/p2p/test/backpressure/backpressure_test.go b/p2p/test/backpressure/backpressure_test.go index b5d6a548c4..20336f02e8 100644 --- a/p2p/test/backpressure/backpressure_test.go +++ b/p2p/test/backpressure/backpressure_test.go @@ -6,13 +6,14 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" + bhost "github.com/libp2p/go-libp2p/p2p/host/basic" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" - logging "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/protocol" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" - bhost "github.com/libp2p/go-libp2p/p2p/host/basic" + + logging "github.com/ipfs/go-log/v2" + "github.com/stretchr/testify/require" ) var log = logging.Logger("backpressure") diff --git a/p2p/test/reconnects/reconnect_test.go b/p2p/test/reconnects/reconnect_test.go index 0d71ec5c02..72523ffd26 100644 --- a/p2p/test/reconnects/reconnect_test.go +++ b/p2p/test/reconnects/reconnect_test.go @@ -10,14 +10,13 @@ import ( "time" bhost "github.com/libp2p/go-libp2p/p2p/host/basic" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/protocol" - swarmt "github.com/libp2p/go-libp2p-swarm/testing" - "github.com/stretchr/testify/require" )
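With this change, the swarm test helpers live at github.com/libp2p/go-libp2p/p2p/net/swarm/testing instead of github.com/libp2p/go-libp2p-swarm/testing. As a minimal sketch of what a downstream test looks like against the relocated package (the package name, test name and assertions here are illustrative and not part of the diff; only GenSwarm, DivulgeAddresses and the swarm methods shown are taken from the code above):

package example_test

import (
	"context"
	"testing"

	swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing"

	"github.com/libp2p/go-libp2p-core/network"

	"github.com/stretchr/testify/require"
)

// TestDialWithRelocatedHelpers is a hypothetical example, not part of the PR.
func TestDialWithRelocatedHelpers(t *testing.T) {
	// GenSwarm builds a swarm wired with TCP, QUIC and a yamux/insecure
	// upgrader; options such as OptDialOnly or OptDisableQUIC narrow it down.
	a := swarmt.GenSwarm(t)
	defer a.Close()
	b := swarmt.GenSwarm(t)
	defer b.Close()

	// Teach a about b's listen addresses, then dial b from a.
	swarmt.DivulgeAddresses(b, a)
	_, err := a.DialPeer(context.Background(), b.LocalPeer())
	require.NoError(t, err)
	require.Equal(t, network.Connected, a.Connectedness(b.LocalPeer()))
}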