diff --git a/p2p/transport/tcp/tcp_test.go b/p2p/transport/tcp/tcp_test.go index 06ea31b2e9..c0f02e2abc 100644 --- a/p2p/transport/tcp/tcp_test.go +++ b/p2p/transport/tcp/tcp_test.go @@ -8,6 +8,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/muxer/yamux" csms "github.com/libp2p/go-libp2p/p2p/net/conn-security-multistream" tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" + ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite" "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/network" @@ -17,7 +18,6 @@ import ( "github.com/libp2p/go-libp2p-core/transport" mocknetwork "github.com/libp2p/go-libp2p-testing/mocks/network" - ttransport "github.com/libp2p/go-libp2p-testing/suites/transport" ma "github.com/multiformats/go-multiaddr" diff --git a/p2p/transport/testsuite/stream_suite.go b/p2p/transport/testsuite/stream_suite.go new file mode 100644 index 0000000000..e7770bfa71 --- /dev/null +++ b/p2p/transport/testsuite/stream_suite.go @@ -0,0 +1,448 @@ +package ttransport + +import ( + "bytes" + "context" + "fmt" + "io" + "io/ioutil" + "os" + "strconv" + "sync" + "testing" + "time" + + crand "crypto/rand" + mrand "math/rand" + + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/transport" + "github.com/libp2p/go-libp2p-testing/race" + + ma "github.com/multiformats/go-multiaddr" +) + +var randomness []byte + +var StressTestTimeout = 1 * time.Minute + +func init() { + // read 1MB of randomness + randomness = make([]byte, 1<<20) + if _, err := crand.Read(randomness); err != nil { + panic(err) + } + + if timeout := os.Getenv("TEST_STRESS_TIMEOUT_MS"); timeout != "" { + if v, err := strconv.ParseInt(timeout, 10, 32); err == nil { + StressTestTimeout = time.Duration(v) * time.Millisecond + } + } +} + +type Options struct { + ConnNum int + StreamNum int + MsgNum int + MsgMin int + MsgMax int +} + +func fullClose(t *testing.T, s network.MuxedStream) { + if err := s.CloseWrite(); err != nil { + t.Error(err) + s.Reset() + return + } + b, err := ioutil.ReadAll(s) + if err != nil { + t.Error(err) + } + if len(b) != 0 { + t.Error("expected to be done reading") + } + if err := s.Close(); err != nil { + t.Error(err) + } +} + +func randBuf(size int) []byte { + n := len(randomness) - size + if size < 1 { + panic(fmt.Errorf("requested too large buffer (%d). max is %d", size, len(randomness))) + } + + start := mrand.Intn(n) + return randomness[start : start+size] +} + +func echoStream(t *testing.T, s network.MuxedStream) { + // echo everything + if _, err := io.Copy(s, s); err != nil { + t.Error(err) + } +} + +func echo(t *testing.T, c transport.CapableConn) { + var wg sync.WaitGroup + defer wg.Wait() + for { + str, err := c.AcceptStream() + if err != nil { + break + } + wg.Add(1) + go func() { + defer wg.Done() + defer str.Close() + echoStream(t, str) + }() + } +} + +func serve(t *testing.T, l transport.Listener) { + var wg sync.WaitGroup + defer wg.Wait() + + for { + c, err := l.Accept() + if err != nil { + return + } + defer c.Close() + + wg.Add(1) + go func() { + defer wg.Done() + echo(t, c) + }() + } +} + +func SubtestStress(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID, opt Options) { + msgsize := 1 << 11 + + rateLimitN := 5000 // max of 5k funcs, because -race has 8k max. + rateLimitChan := make(chan struct{}, rateLimitN) + for i := 0; i < rateLimitN; i++ { + rateLimitChan <- struct{}{} + } + + rateLimit := func(f func()) { + <-rateLimitChan + f() + rateLimitChan <- struct{}{} + } + + writeStream := func(s network.MuxedStream, bufs chan<- []byte) { + for i := 0; i < opt.MsgNum; i++ { + buf := randBuf(msgsize) + bufs <- buf + if _, err := s.Write(buf); err != nil { + t.Errorf("s.Write(buf): %s", err) + return + } + } + } + + readStream := func(s network.MuxedStream, bufs <-chan []byte) { + buf2 := make([]byte, msgsize) + i := 0 + for buf1 := range bufs { + i++ + + if _, err := io.ReadFull(s, buf2); err != nil { + t.Errorf("io.ReadFull(s, buf2): %s", err) + return + } + if !bytes.Equal(buf1, buf2) { + t.Errorf("buffers not equal (%x != %x)", buf1[:3], buf2[:3]) + return + } + } + } + + openStreamAndRW := func(c network.MuxedConn) { + s, err := c.OpenStream(context.Background()) + if err != nil { + t.Errorf("failed to create NewStream: %s", err) + return + } + + bufs := make(chan []byte, opt.MsgNum) + go func() { + writeStream(s, bufs) + close(bufs) + }() + + readStream(s, bufs) + fullClose(t, s) + } + + openConnAndRW := func() { + var wg sync.WaitGroup + defer wg.Wait() + + l, err := ta.Listen(maddr) + if err != nil { + t.Error(err) + return + } + defer l.Close() + + wg.Add(1) + go func() { + defer wg.Done() + serve(t, l) + }() + + c, err := tb.Dial(context.Background(), l.Multiaddr(), peerA) + if err != nil { + t.Error(err) + return + } + defer c.Close() + + // serve the outgoing conn, because some muxers assume + // that we _always_ call serve. (this is an error?) + wg.Add(1) + go func() { + defer wg.Done() + echo(t, c) + }() + + var openWg sync.WaitGroup + for i := 0; i < opt.StreamNum; i++ { + openWg.Add(1) + go rateLimit(func() { + defer openWg.Done() + openStreamAndRW(c) + }) + } + openWg.Wait() + } + + var wg sync.WaitGroup + defer wg.Wait() + for i := 0; i < opt.ConnNum; i++ { + wg.Add(1) + go rateLimit(func() { + defer wg.Done() + openConnAndRW() + }) + } +} + +func SubtestStreamOpenStress(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID) { + l, err := ta.Listen(maddr) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + count := 10000 + workers := 5 + + if race.WithRace() { + // the race detector can only deal with 8128 simultaneous goroutines, so let's make sure we don't go overboard. + count = 1000 + } + + var ( + connA, connB transport.CapableConn + ) + + accepted := make(chan error, 1) + go func() { + var err error + connA, err = l.Accept() + accepted <- err + }() + connB, err = tb.Dial(context.Background(), l.Multiaddr(), peerA) + if err != nil { + t.Fatal(err) + } + err = <-accepted + if err != nil { + t.Fatal(err) + } + + defer func() { + if connA != nil { + connA.Close() + } + if connB != nil { + connB.Close() + } + }() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < workers; j++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < count; i++ { + s, err := connA.OpenStream(context.Background()) + if err != nil { + t.Error(err) + return + } + wg.Add(1) + go func() { + defer wg.Done() + fullClose(t, s) + }() + } + }() + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < count*workers; i++ { + str, err := connB.AcceptStream() + if err != nil { + break + } + wg.Add(1) + go func() { + defer wg.Done() + fullClose(t, str) + }() + } + }() + + timeout := time.After(StressTestTimeout) + done := make(chan struct{}) + + go func() { + wg.Wait() + close(done) + }() + + select { + case <-timeout: + t.Fatal("timed out receiving streams") + case <-done: + } +} + +func SubtestStreamReset(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID) { + var wg sync.WaitGroup + defer wg.Wait() + + l, err := ta.Listen(maddr) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + wg.Add(1) + go func() { + defer wg.Done() + + muxa, err := l.Accept() + if err != nil { + t.Error(err) + return + } + defer muxa.Close() + + s, err := muxa.OpenStream(context.Background()) + if err != nil { + t.Error(err) + return + } + defer s.Close() + + // Some transports won't open the stream until we write. That's + // fine. + _, _ = s.Write([]byte("foo")) + + time.Sleep(time.Millisecond * 50) + + _, err = s.Write([]byte("bar")) + if err == nil { + t.Error("should have failed to write") + } + + }() + + muxb, err := tb.Dial(context.Background(), l.Multiaddr(), peerA) + if err != nil { + t.Fatal(err) + } + defer muxb.Close() + + str, err := muxb.AcceptStream() + if err != nil { + t.Error(err) + return + } + str.Reset() +} + +func SubtestStress1Conn1Stream1Msg(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID) { + SubtestStress(t, ta, tb, maddr, peerA, Options{ + ConnNum: 1, + StreamNum: 1, + MsgNum: 1, + MsgMax: 100, + MsgMin: 100, + }) +} + +func SubtestStress1Conn1Stream100Msg(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID) { + SubtestStress(t, ta, tb, maddr, peerA, Options{ + ConnNum: 1, + StreamNum: 1, + MsgNum: 100, + MsgMax: 100, + MsgMin: 100, + }) +} + +func SubtestStress1Conn100Stream100Msg(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID) { + SubtestStress(t, ta, tb, maddr, peerA, Options{ + ConnNum: 1, + StreamNum: 100, + MsgNum: 100, + MsgMax: 100, + MsgMin: 100, + }) +} + +func SubtestStress50Conn10Stream50Msg(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID) { + SubtestStress(t, ta, tb, maddr, peerA, Options{ + ConnNum: 50, + StreamNum: 10, + MsgNum: 50, + MsgMax: 100, + MsgMin: 100, + }) +} + +func SubtestStress1Conn1000Stream10Msg(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID) { + SubtestStress(t, ta, tb, maddr, peerA, Options{ + ConnNum: 1, + StreamNum: 1000, + MsgNum: 10, + MsgMax: 100, + MsgMin: 100, + }) +} + +func SubtestStress1Conn100Stream100Msg10MB(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID) { + SubtestStress(t, ta, tb, maddr, peerA, Options{ + ConnNum: 1, + StreamNum: 100, + MsgNum: 100, + MsgMax: 10000, + MsgMin: 1000, + }) +} diff --git a/p2p/transport/testsuite/transport_suite.go b/p2p/transport/testsuite/transport_suite.go new file mode 100644 index 0000000000..6e6b300543 --- /dev/null +++ b/p2p/transport/testsuite/transport_suite.go @@ -0,0 +1,305 @@ +package ttransport + +import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "sync" + "testing" + + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/transport" + + ma "github.com/multiformats/go-multiaddr" +) + +var testData = []byte("this is some test data") + +func SubtestProtocols(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID) { + rawIPAddr, _ := ma.NewMultiaddr("/ip4/1.2.3.4") + if ta.CanDial(rawIPAddr) || tb.CanDial(rawIPAddr) { + t.Error("nothing should be able to dial raw IP") + } + + tprotos := make(map[int]bool) + for _, p := range ta.Protocols() { + tprotos[p] = true + } + + if !ta.Proxy() { + protos := maddr.Protocols() + proto := protos[len(protos)-1] + if !tprotos[proto.Code] { + t.Errorf("transport should have reported that it supports protocol '%s' (%d)", proto.Name, proto.Code) + } + } else { + found := false + for _, proto := range maddr.Protocols() { + if tprotos[proto.Code] { + found = true + break + } + } + if !found { + t.Errorf("didn't find any matching proxy protocols in maddr: %s", maddr) + } + } +} + +func SubtestBasic(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + list, err := ta.Listen(maddr) + if err != nil { + t.Fatal(err) + } + defer list.Close() + + var ( + connA, connB transport.CapableConn + done = make(chan struct{}) + ) + defer func() { + <-done + if connA != nil { + connA.Close() + } + if connB != nil { + connB.Close() + } + }() + + go func() { + defer close(done) + var err error + connB, err = list.Accept() + if err != nil { + t.Error(err) + return + } + s, err := connB.AcceptStream() + if err != nil { + t.Error(err) + return + } + + buf, err := ioutil.ReadAll(s) + if err != nil { + t.Error(err) + return + } + + if !bytes.Equal(testData, buf) { + t.Errorf("expected %s, got %s", testData, buf) + } + + n, err := s.Write(testData) + if err != nil { + t.Error(err) + return + } + if n != len(testData) { + t.Error(err) + return + } + + err = s.Close() + if err != nil { + t.Error(err) + } + }() + + if !tb.CanDial(list.Multiaddr()) { + t.Error("CanDial should have returned true") + } + + connA, err = tb.Dial(ctx, list.Multiaddr(), peerA) + if err != nil { + t.Fatal(err) + } + + s, err := connA.OpenStream(context.Background()) + if err != nil { + t.Fatal(err) + } + + n, err := s.Write(testData) + if err != nil { + t.Fatal(err) + return + } + + if n != len(testData) { + t.Fatalf("failed to write enough data (a->b)") + return + } + + if err = s.CloseWrite(); err != nil { + t.Fatal(err) + return + } + + buf, err := ioutil.ReadAll(s) + if err != nil { + t.Fatal(err) + return + } + if !bytes.Equal(testData, buf) { + t.Errorf("expected %s, got %s", testData, buf) + } + + if err = s.Close(); err != nil { + t.Fatal(err) + return + } +} + +func SubtestPingPong(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID) { + streams := 100 + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + list, err := ta.Listen(maddr) + if err != nil { + t.Fatal(err) + } + defer list.Close() + + var ( + connA, connB transport.CapableConn + ) + defer func() { + if connA != nil { + connA.Close() + } + if connB != nil { + connB.Close() + } + }() + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + var err error + connA, err = list.Accept() + if err != nil { + t.Error(err) + return + } + + var sWg sync.WaitGroup + for i := 0; i < streams; i++ { + s, err := connA.AcceptStream() + if err != nil { + t.Error(err) + return + } + + sWg.Add(1) + go func() { + defer sWg.Done() + + data, err := ioutil.ReadAll(s) + if err != nil { + s.Reset() + t.Error(err) + return + } + if !bytes.HasPrefix(data, testData) { + t.Errorf("expected %q to have prefix %q", string(data), string(testData)) + } + + n, err := s.Write(data) + if err != nil { + s.Reset() + t.Error(err) + return + } + + if n != len(data) { + s.Reset() + t.Error(err) + return + } + s.Close() + }() + } + sWg.Wait() + }() + + if !tb.CanDial(list.Multiaddr()) { + t.Error("CanDial should have returned true") + } + + connB, err = tb.Dial(ctx, list.Multiaddr(), peerA) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < streams; i++ { + s, err := connB.OpenStream(context.Background()) + if err != nil { + t.Error(err) + continue + } + + wg.Add(1) + go func(i int) { + defer wg.Done() + data := []byte(fmt.Sprintf("%s - %d", testData, i)) + n, err := s.Write(data) + if err != nil { + s.Reset() + t.Error(err) + return + } + + if n != len(data) { + s.Reset() + t.Error("failed to write enough data (a->b)") + return + } + if err = s.CloseWrite(); err != nil { + t.Error(err) + return + } + + ret, err := ioutil.ReadAll(s) + if err != nil { + s.Reset() + t.Error(err) + return + } + if !bytes.Equal(data, ret) { + t.Errorf("expected %q, got %q", string(data), string(ret)) + } + + if err = s.Close(); err != nil { + t.Error(err) + return + } + }(i) + } + wg.Wait() +} + +func SubtestCancel(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID) { + list, err := ta.Listen(maddr) + if err != nil { + t.Fatal(err) + } + defer list.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c, err := tb.Dial(ctx, list.Multiaddr(), peerA) + if err == nil { + c.Close() + t.Fatal("dial should have failed") + } +} diff --git a/p2p/transport/testsuite/utils_suite.go b/p2p/transport/testsuite/utils_suite.go new file mode 100644 index 0000000000..1d520ff262 --- /dev/null +++ b/p2p/transport/testsuite/utils_suite.go @@ -0,0 +1,45 @@ +package ttransport + +import ( + "reflect" + "runtime" + "testing" + + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/transport" + + ma "github.com/multiformats/go-multiaddr" +) + +var Subtests = []func(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID){ + SubtestProtocols, + SubtestBasic, + SubtestCancel, + SubtestPingPong, + + // Stolen from the stream muxer test suite. + SubtestStress1Conn1Stream1Msg, + SubtestStress1Conn1Stream100Msg, + SubtestStress1Conn100Stream100Msg, + SubtestStress50Conn10Stream50Msg, + SubtestStress1Conn1000Stream10Msg, + SubtestStress1Conn100Stream100Msg10MB, + SubtestStreamOpenStress, + SubtestStreamReset, +} + +func getFunctionName(i interface{}) string { + return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() +} + +func SubtestTransport(t *testing.T, ta, tb transport.Transport, addr string, peerA peer.ID) { + maddr, err := ma.NewMultiaddr(addr) + if err != nil { + t.Fatal(err) + } + for _, f := range Subtests { + t.Run(getFunctionName(f), func(t *testing.T) { + f(t, ta, tb, maddr, peerA) + }) + } +} diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index b83f528f84..e7c7aa0f44 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -17,6 +17,7 @@ import ( csms "github.com/libp2p/go-libp2p/p2p/net/conn-security-multistream" tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" + ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite" "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/network" @@ -27,8 +28,6 @@ import ( "github.com/libp2p/go-libp2p-core/transport" "github.com/libp2p/go-libp2p/p2p/muxer/yamux" - ttransport "github.com/libp2p/go-libp2p-testing/suites/transport" - ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" )