diff --git a/core/core.go b/core/core.go index 8a381f0b44a..6e6cf8dae79 100644 --- a/core/core.go +++ b/core/core.go @@ -106,7 +106,7 @@ func NewIpfsNode(cfg *config.Config, online bool) (*IpfsNode, error) { return nil, err } - net, err = inet.NewIpfsNetwork(context.TODO(), local, &mux.ProtocolMap{ + net, err = inet.NewIpfsNetwork(context.TODO(), local, peerstore, &mux.ProtocolMap{ mux.ProtocolID_Routing: dhtService, mux.ProtocolID_Exchange: exchangeService, }) diff --git a/core/core_test.go b/core/core_test.go index c6695eb6b63..60555c845af 100644 --- a/core/core_test.go +++ b/core/core_test.go @@ -7,9 +7,8 @@ import ( ) func TestInitialization(t *testing.T) { - id := &config.Identity{ + id := config.Identity{ PeerID: "QmNgdzLieYi8tgfo2WfTUzNVH5hQK9oAYGVf6dxN12NrHt", - Address: "/ip4/127.0.0.1/tcp/8000", PrivKey: "CAASrRIwggkpAgEAAoICAQCwt67GTUQ8nlJhks6CgbLKOx7F5tl1r9zF4m3TUrG3Pe8h64vi+ILDRFd7QJxaJ/n8ux9RUDoxLjzftL4uTdtv5UXl2vaufCc/C0bhCRvDhuWPhVsD75/DZPbwLsepxocwVWTyq7/ZHsCfuWdoh/KNczfy+Gn33gVQbHCnip/uhTVxT7ARTiv8Qa3d7qmmxsR+1zdL/IRO0mic/iojcb3Oc/PRnYBTiAZFbZdUEit/99tnfSjMDg02wRayZaT5ikxa6gBTMZ16Yvienq7RwSELzMQq2jFA4i/TdiGhS9uKywltiN2LrNDBcQJSN02pK12DKoiIy+wuOCRgs2NTQEhU2sXCk091v7giTTOpFX2ij9ghmiRfoSiBFPJA5RGwiH6ansCHtWKY1K8BS5UORM0o3dYk87mTnKbCsdz4bYnGtOWafujYwzueGx8r+IWiys80IPQKDeehnLW6RgoyjszKgL/2XTyP54xMLSW+Qb3BPgDcPaPO0hmop1hW9upStxKsefW2A2d46Ds4HEpJEry7PkS5M4gKL/zCKHuxuXVk14+fZQ1rstMuvKjrekpAC2aVIKMI9VRA3awtnje8HImQMdj+r+bPmv0N8rTTr3eS4J8Yl7k12i95LLfK+fWnmUh22oTNzkRlaiERQrUDyE4XNCtJc0xs1oe1yXGqazCIAQIDAQABAoICAQCk1N/ftahlRmOfAXk//8wNl7FvdJD3le6+YSKBj0uWmN1ZbUSQk64chr12iGCOM2WY180xYjy1LOS44PTXaeW5bEiTSnb3b3SH+HPHaWCNM2EiSogHltYVQjKW+3tfH39vlOdQ9uQ+l9Gh6iTLOqsCRyszpYPqIBwi1NMLY2Ej8PpVU7ftnFWouHZ9YKS7nAEiMoowhTu/7cCIVwZlAy3AySTuKxPMVj9LORqC32PVvBHZaMPJ+X1Xyijqg6aq39WyoztkXg3+Xxx5j5eOrK6vO/Lp6ZUxaQilHDXoJkKEJjgIBDZpluss08UPfOgiWAGkW+L4fgUxY0qDLDAEMhyEBAn6KOKVL1JhGTX6GjhWziI94bddSpHKYOEIDzUy4H8BXnKhtnyQV6ELS65C2hj9D0IMBTj7edCF1poJy0QfdK0cuXgMvxHLeUO5uc2YWfbNosvKxqygB9rToy4b22YvNwsZUXsTY6Jt+p9V2OgXSKfB5VPeRbjTJL6xqvvUJpQytmII/C9JmSDUtCbYceHj6X9jgigLk20VV6nWHqCTj3utXD6NPAjoycVpLKDlnWEgfVELDIk0gobxUqqSm3jTPEKRPJgxkgPxbwxYumtw++1UY2y35w3WRDc2xYPaWKBCQeZy+mL6ByXp9bWlNvxS3Knb6oZp36/ovGnf2pGvdQKCAQEAyKpipz2lIUySDyE0avVWAmQb2tWGKXALPohzj7AwkcfEg2GuwoC6GyVE2sTJD1HRazIjOKn3yQORg2uOPeG7sx7EKHxSxCKDrbPawkvLCq8JYSy9TLvhqKUVVGYPqMBzu2POSLEA81QXas+aYjKOFWA2Zrjq26zV9ey3+6Lc6WULePgRQybU8+RHJc6fdjUCCfUxgOrUO2IQOuTJ+FsDpVnrMUGlokmWn23OjL4qTL9wGDnWGUs2pjSzNbj3qA0d8iqaiMUyHX/D/VS0wpeT1osNBSm8suvSibYBn+7wbIApbwXUxZaxMv2OHGz3empae4ckvNZs7r8wsI9UwFt8mwKCAQEA4XK6gZkv9t+3YCcSPw2ensLvL/xU7i2bkC9tfTGdjnQfzZXIf5KNdVuj/SerOl2S1s45NMs3ysJbADwRb4ahElD/V71nGzV8fpFTitC20ro9fuX4J0+twmBolHqeH9pmeGTjAeL1rvt6vxs4FkeG/yNft7GdXpXTtEGaObn8Mt0tPY+aB3UnKrnCQoQAlPyGHFrVRX0UEcp6wyyNGhJCNKeNOvqCHTFObhbhO+KWpWSN0MkVHnqaIBnIn1Te8FtvP/iTwXGnKc0YXJUG6+LM6LmOguW6tg8ZqiQeYyyR+e9eCFH4csLzkrTl1GxCxwEsoSLIMm7UDcjttW6tYEghkwKCAQEAmeCO5lCPYImnN5Lu71ZTLmI2OgmjaANTnBBnDbi+hgv61gUCToUIMejSdDCTPfwv61P3TmyIZs0luPGxkiKYHTNqmOE9Vspgz8Mr7fLRMNApESuNvloVIY32XVImj/GEzh4rAfM6F15U1sN8T/EUo6+0B/Glp+9R49QzAfRSE2g48/rGwgf1JVHYfVWFUtAzUA+GdqWdOixo5cCsYJbqpNHfWVZN/bUQnBFIYwUwysnC29D+LUdQEQQ4qOm+gFAOtrWU62zMkXJ4iLt8Ify6kbrvsRXgbhQIzzGS7WH9XDarj0eZciuslr15TLMC1Azadf+cXHLR9gMHA13mT9vYIQKCAQA/DjGv8cKCkAvf7s2hqROGYAs6Jp8yhrsN1tYOwAPLRhtnCs+rLrg17M2vDptLlcRuI/vIElamdTmylRpjUQpX7yObzLO73nfVhpwRJVMdGU394iBIDncQ+JoHfUwgqJskbUM40dvZdyjbrqc/Q/4z+hbZb+oN/GXb8sVKBATPzSDMKQ/xqgisYIw+wmDPStnPsHAaIWOtni47zIgilJzD0WEk78/YjmPbUrboYvWziK5JiRRJFA1rkQqV1c0M+OXixIm+/yS8AksgCeaHr0WUieGcJtjT9uE8vyFop5ykhRiNxy9wGaq6i7IEecsrkd6DqxDHWkwhFuO1bSE83q/VAoIBAEA+RX1i/SUi08p71ggUi9WFMqXmzELp1L3hiEjOc2AklHk2rPxsaTh9+G95BvjhP7fRa/Yga+yDtYuyjO99nedStdNNSg03aPXILl9gs3r2dPiQKUEXZJ3FrH6tkils/8BlpOIRfbkszrdZIKTO9GCdLWQ30dQITDACs8zV/1GFGrHFrqnnMe/NpIFHWNZJ0/WZMi8wgWO6Ik8jHEpQtVXRiXLqy7U6hk170pa4GHOzvftfPElOZZjy9qn7KjdAQqy6spIrAE94OEL+fBgbHQZGLpuTlj6w6YGbMtPU8uo7sXKoc6WOCb68JWft3tejGLDa1946HAWqVM9B/UcneNc=", } @@ -19,6 +18,10 @@ func TestInitialization(t *testing.T) { Datastore: config.Datastore{ Type: "memory", }, + Addresses: config.Addresses{ + Swarm: "/ip4/0.0.0.0/tcp/4001", + API: "/ip4/127.0.0.1/tcp/8000", + }, }, &config.Config{ @@ -27,6 +30,10 @@ func TestInitialization(t *testing.T) { Type: "leveldb", Path: ".testdb", }, + Addresses: config.Addresses{ + Swarm: "/ip4/0.0.0.0/tcp/4001", + API: "/ip4/127.0.0.1/tcp/8000", + }, }, } diff --git a/crypto/key.go b/crypto/key.go index 38b3b0ebdca..f0a35c698ab 100644 --- a/crypto/key.go +++ b/crypto/key.go @@ -23,7 +23,17 @@ const ( RSA = iota ) +type Key interface { + // Bytes returns a serialized, storeable representation of this key + Bytes() ([]byte, error) + + // Equals checks whether two PubKeys are the same + Equals(Key) bool +} + type PrivKey interface { + Key + // Cryptographically sign the given bytes Sign([]byte) ([]byte, error) @@ -32,17 +42,13 @@ type PrivKey interface { // Generate a secret string of bytes GenSecret() []byte - - // Bytes returns a serialized, storeable representation of this key - Bytes() ([]byte, error) } type PubKey interface { + Key + // Verify that 'sig' is the signed hash of 'data' Verify(data []byte, sig []byte) (bool, error) - - // Bytes returns a serialized, storeable representation of this key - Bytes() ([]byte, error) } // Given a public key, generates the shared key. @@ -229,3 +235,14 @@ func UnmarshalPrivateKey(data []byte) (PrivKey, error) { return nil, ErrBadKeyType } } + +// KeyEqual checks whether two +func KeyEqual(k1, k2 Key) bool { + if k1 == k2 { + return true + } + + b1, err1 := k1.Bytes() + b2, err2 := k2.Bytes() + return bytes.Equal(b1, b2) && err1 == err2 +} diff --git a/crypto/key_test.go b/crypto/key_test.go index c002c581965..13c94215e80 100644 --- a/crypto/key_test.go +++ b/crypto/key_test.go @@ -3,12 +3,14 @@ package crypto import "testing" func TestRsaKeys(t *testing.T) { - sk, _, err := GenerateKeyPair(RSA, 512) + sk, pk, err := GenerateKeyPair(RSA, 512) if err != nil { t.Fatal(err) } testKeySignature(t, sk) testKeyEncoding(t, sk) + testKeyEquals(t, sk) + testKeyEquals(t, pk) } func testKeySignature(t *testing.T, sk PrivKey) { @@ -52,3 +54,41 @@ func testKeyEncoding(t *testing.T, sk PrivKey) { t.Fatal(err) } } + +func testKeyEquals(t *testing.T, k Key) { + kb, err := k.Bytes() + if err != nil { + t.Fatal(err) + } + + if !KeyEqual(k, k) { + t.Fatal("Key not equal to itself.") + } + + if !KeyEqual(k, testkey(kb)) { + t.Fatal("Key not equal to key with same bytes.") + } + + sk, pk, err := GenerateKeyPair(RSA, 512) + if err != nil { + t.Fatal(err) + } + + if KeyEqual(k, sk) { + t.Fatal("Keys should not equal.") + } + + if KeyEqual(k, pk) { + t.Fatal("Keys should not equal.") + } +} + +type testkey []byte + +func (pk testkey) Bytes() ([]byte, error) { + return pk, nil +} + +func (pk testkey) Equals(k Key) bool { + return KeyEqual(pk, k) +} diff --git a/crypto/rsa.go b/crypto/rsa.go index 513b868d171..e582b59c297 100644 --- a/crypto/rsa.go +++ b/crypto/rsa.go @@ -41,6 +41,11 @@ func (pk *RsaPublicKey) Bytes() ([]byte, error) { return proto.Marshal(pbmes) } +// Equals checks whether this key is equal to another +func (pk *RsaPublicKey) Equals(k Key) bool { + return KeyEqual(pk, k) +} + func (sk *RsaPrivateKey) GenSecret() []byte { buf := make([]byte, 16) rand.Read(buf) @@ -65,6 +70,11 @@ func (sk *RsaPrivateKey) Bytes() ([]byte, error) { return proto.Marshal(pbmes) } +// Equals checks whether this key is equal to another +func (sk *RsaPrivateKey) Equals(k Key) bool { + return KeyEqual(sk, k) +} + func UnmarshalRsaPrivateKey(b []byte) (*RsaPrivateKey, error) { sk, err := x509.ParsePKCS1PrivateKey(b) if err != nil { diff --git a/crypto/spipe/handshake.go b/crypto/spipe/handshake.go index 8019f6fc4cc..f617c75b3aa 100644 --- a/crypto/spipe/handshake.go +++ b/crypto/spipe/handshake.go @@ -90,23 +90,18 @@ func (s *SecurePipe) handshake() error { return err } - s.remote.PubKey, err = ci.UnmarshalPublicKey(proposeResp.GetPubkey()) + // get remote identity + remotePubKey, err := ci.UnmarshalPublicKey(proposeResp.GetPubkey()) if err != nil { return err } - remoteID, err := IDFromPubKey(s.remote.PubKey) + // get or construct peer + s.remote, err = getOrConstructPeer(s.peers, remotePubKey) if err != nil { return err } - - if s.remote.ID != nil && !remoteID.Equal(s.remote.ID) { - e := "Expected pubkey does not match sent pubkey: %v - %v" - return fmt.Errorf(e, s.remote.ID.Pretty(), remoteID.Pretty()) - } else if s.remote.ID == nil { - s.remote.ID = remoteID - } - // u.POut("Remote Peer Identified as %s\n", s.remote.ID.Pretty()) + u.DOut("[%s] Remote Peer Identified as %s\n", s.local.ID.Pretty(), s.remote.ID.Pretty()) exchange, err := selectBest(SupportedExchanges, proposeResp.GetExchanges()) if err != nil { @@ -340,3 +335,52 @@ func selectBest(myPrefs, theirPrefs string) (string, error) { return "", errors.New("No algorithms in common!") } + +// getOrConstructPeer attempts to fetch a peer from a peerstore. +// if succeeds, verify ID and PubKey match. +// else, construct it. +func getOrConstructPeer(peers peer.Peerstore, rpk ci.PubKey) (*peer.Peer, error) { + + rid, err := IDFromPubKey(rpk) + if err != nil { + return nil, err + } + + npeer, err := peers.Get(rid) + if err != nil || npeer == nil { + if err != peer.ErrNotFound { + return nil, err // unexpected error happened. + } + + // dont have peer, so construct it + add it to peerstore. + npeer = &peer.Peer{ID: rid, PubKey: rpk} + if err := peers.Put(npeer); err != nil { + return nil, err + } + + // done, return the newly constructed peer. + return npeer, nil + } + + // did have it locally. + + // let's verify ID + if !npeer.ID.Equal(rid) { + e := "Expected peer.ID does not match sent pubkey's hash: %v - %v" + return nil, fmt.Errorf(e, npeer.ID.Pretty(), rid.Pretty()) + } + + if npeer.PubKey == nil { + // didn't have a pubkey, just set it. + npeer.PubKey = rpk + return npeer, nil + } + + // did have pubkey, let's verify it's really the same. + // this shouldn't ever happen, given we hashed, etc, but it could mean + // expected code (or protocol) invariants violated. + if !npeer.PubKey.Equals(rpk) { + return nil, fmt.Errorf("WARNING: PubKey mismatch: %v", npeer.ID.Pretty()) + } + return npeer, nil +} diff --git a/crypto/spipe/pipe.go b/crypto/spipe/pipe.go index caa539275ac..8d0db0d5dbe 100644 --- a/crypto/spipe/pipe.go +++ b/crypto/spipe/pipe.go @@ -20,6 +20,7 @@ type SecurePipe struct { local *peer.Peer remote *peer.Peer + peers peer.Peerstore params params @@ -32,16 +33,16 @@ type params struct { } // NewSecurePipe constructs a pipe with channels of a given buffer size. -func NewSecurePipe(ctx context.Context, bufsize int, local, - remote *peer.Peer) (*SecurePipe, error) { +func NewSecurePipe(ctx context.Context, bufsize int, local *peer.Peer, + peers peer.Peerstore) (*SecurePipe, error) { sp := &SecurePipe{ Duplex: Duplex{ In: make(chan []byte, bufsize), Out: make(chan []byte, bufsize), }, - local: local, - remote: remote, + local: local, + peers: peers, } return sp, nil } @@ -63,6 +64,16 @@ func (s *SecurePipe) Wrap(ctx context.Context, insecure Duplex) error { return nil } +// LocalPeer retrieves the local peer. +func (s *SecurePipe) LocalPeer() *peer.Peer { + return s.local +} + +// RemotePeer retrieves the local peer. +func (s *SecurePipe) RemotePeer() *peer.Peer { + return s.remote +} + // Close closes the secure pipe func (s *SecurePipe) Close() error { if s.cancel == nil { diff --git a/daemon/daemon_test.go b/daemon/daemon_test.go index b509bb76c2c..9b0f4f44bdf 100644 --- a/daemon/daemon_test.go +++ b/daemon/daemon_test.go @@ -26,9 +26,8 @@ func TestInitializeDaemonListener(t *testing.T) { privKey := base64.StdEncoding.EncodeToString(prbytes) pID := ident.Pretty() - id := &config.Identity{ + id := config.Identity{ PeerID: pID, - Address: "/ip4/127.0.0.1/tcp/8000", PrivKey: privKey, } @@ -38,6 +37,10 @@ func TestInitializeDaemonListener(t *testing.T) { Datastore: config.Datastore{ Type: "memory", }, + Addresses: config.Addresses{ + Swarm: "/ip4/0.0.0.0/tcp/4001", + API: "/ip4/127.0.0.1/tcp/8000", + }, }, &config.Config{ @@ -46,6 +49,10 @@ func TestInitializeDaemonListener(t *testing.T) { Type: "leveldb", Path: ".testdb", }, + Addresses: config.Addresses{ + Swarm: "/ip4/0.0.0.0/tcp/4001", + API: "/ip4/127.0.0.1/tcp/8000", + }, }, } diff --git a/net/conn/conn.go b/net/conn/conn.go index 4e6ded25854..645264b8da4 100644 --- a/net/conn/conn.go +++ b/net/conn/conn.go @@ -48,17 +48,6 @@ func NewConn(peer *peer.Peer, addr *ma.Multiaddr, nconn net.Conn) (*Conn, error) return conn, nil } -// NewNetConn constructs a new connection with given net.Conn -func NewNetConn(nconn net.Conn) (*Conn, error) { - - addr, err := ma.FromNetAddr(nconn.RemoteAddr()) - if err != nil { - return nil, err - } - - return NewConn(new(peer.Peer), addr, nconn) -} - // Dial connects to a particular peer, over a given network // Example: Dial("udp", peer) func Dial(network string, peer *peer.Peer) (*Conn, error) { @@ -112,3 +101,9 @@ func (c *Conn) Close() error { c.Closed <- true return err } + +// NetConnMultiaddr returns the net.Conn's address, recast as a multiaddr. +// (consider moving this directly into the multiaddr package) +func NetConnMultiaddr(nconn net.Conn) (*ma.Multiaddr, error) { + return ma.FromNetAddr(nconn.RemoteAddr()) +} diff --git a/net/mux/mux.go b/net/mux/mux.go index 57cfe334302..e02a926d35b 100644 --- a/net/mux/mux.go +++ b/net/mux/mux.go @@ -2,6 +2,7 @@ package mux import ( "errors" + "sync" msg "github.com/jbenet/go-ipfs/net/message" u "github.com/jbenet/go-ipfs/util" @@ -30,6 +31,8 @@ type Muxer struct { // cancel is the function to stop the Muxer cancel context.CancelFunc + ctx context.Context + wg sync.WaitGroup *msg.Pipe } @@ -58,11 +61,14 @@ func (m *Muxer) Start(ctx context.Context) error { } // make a cancellable context. - ctx, m.cancel = context.WithCancel(ctx) + m.ctx, m.cancel = context.WithCancel(ctx) + m.wg = sync.WaitGroup{} - go m.handleIncomingMessages(ctx) + m.wg.Add(1) + go m.handleIncomingMessages() for pid, proto := range m.Protocols { - go m.handleOutgoingMessages(ctx, pid, proto) + m.wg.Add(1) + go m.handleOutgoingMessages(pid, proto) } return nil @@ -70,8 +76,15 @@ func (m *Muxer) Start(ctx context.Context) error { // Stop stops muxer activity. func (m *Muxer) Stop() { + if m.cancel == nil { + panic("muxer stopped twice.") + } + // issue cancel, and wipe func. m.cancel() m.cancel = context.CancelFunc(nil) + + // wait for everything to wind down. + m.wg.Wait() } // AddProtocol adds a Protocol with given ProtocolID to the Muxer. @@ -86,7 +99,8 @@ func (m *Muxer) AddProtocol(p Protocol, pid ProtocolID) error { // handleIncoming consumes the messages on the m.Incoming channel and // routes them appropriately (to the protocols). -func (m *Muxer) handleIncomingMessages(ctx context.Context) { +func (m *Muxer) handleIncomingMessages() { + defer m.wg.Done() for { if m == nil { @@ -98,16 +112,16 @@ func (m *Muxer) handleIncomingMessages(ctx context.Context) { if !more { return } - go m.handleIncomingMessage(ctx, msg) + go m.handleIncomingMessage(msg) - case <-ctx.Done(): + case <-m.ctx.Done(): return } } } // handleIncomingMessage routes message to the appropriate protocol. -func (m *Muxer) handleIncomingMessage(ctx context.Context, m1 msg.NetMessage) { +func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) { data, pid, err := unwrapData(m1.Data()) if err != nil { @@ -124,31 +138,33 @@ func (m *Muxer) handleIncomingMessage(ctx context.Context, m1 msg.NetMessage) { select { case proto.GetPipe().Incoming <- m2: - case <-ctx.Done(): - u.PErr("%v\n", ctx.Err()) + case <-m.ctx.Done(): + u.PErr("%v\n", m.ctx.Err()) return } } // handleOutgoingMessages consumes the messages on the proto.Outgoing channel, // wraps them and sends them out. -func (m *Muxer) handleOutgoingMessages(ctx context.Context, pid ProtocolID, proto Protocol) { +func (m *Muxer) handleOutgoingMessages(pid ProtocolID, proto Protocol) { + defer m.wg.Done() + for { select { case msg, more := <-proto.GetPipe().Outgoing: if !more { return } - go m.handleOutgoingMessage(ctx, pid, msg) + go m.handleOutgoingMessage(pid, msg) - case <-ctx.Done(): + case <-m.ctx.Done(): return } } } // handleOutgoingMessage wraps out a message and sends it out the -func (m *Muxer) handleOutgoingMessage(ctx context.Context, pid ProtocolID, m1 msg.NetMessage) { +func (m *Muxer) handleOutgoingMessage(pid ProtocolID, m1 msg.NetMessage) { data, err := wrapData(m1.Data(), pid) if err != nil { u.PErr("muxer serializing error: %v\n", err) @@ -158,7 +174,7 @@ func (m *Muxer) handleOutgoingMessage(ctx context.Context, pid ProtocolID, m1 ms m2 := msg.New(m1.Peer(), data) select { case m.GetPipe().Outgoing <- m2: - case <-ctx.Done(): + case <-m.ctx.Done(): return } } diff --git a/net/mux/mux_test.go b/net/mux/mux_test.go index 6aeeda28c38..17606bf933d 100644 --- a/net/mux/mux_test.go +++ b/net/mux/mux_test.go @@ -229,13 +229,13 @@ func TestStopping(t *testing.T) { mux1.Start(context.Background()) // test outgoing p1 - for _, s := range []string{"foo", "bar", "baz"} { + for _, s := range []string{"foo1", "bar1", "baz1"} { p1.Outgoing <- msg.New(peer1, []byte(s)) testWrappedMsg(t, <-mux1.Outgoing, pid1, []byte(s)) } // test incoming p1 - for _, s := range []string{"foo", "bar", "baz"} { + for _, s := range []string{"foo2", "bar2", "baz2"} { d, err := wrapData([]byte(s), pid1) if err != nil { t.Error(err) @@ -250,17 +250,17 @@ func TestStopping(t *testing.T) { } // test outgoing p1 - for _, s := range []string{"foo", "bar", "baz"} { + for _, s := range []string{"foo3", "bar3", "baz3"} { p1.Outgoing <- msg.New(peer1, []byte(s)) select { - case <-mux1.Outgoing: - t.Error("should not have received anything.") + case m := <-mux1.Outgoing: + t.Errorf("should not have received anything. Got: %v", string(m.Data())) case <-time.After(time.Millisecond): } } // test incoming p1 - for _, s := range []string{"foo", "bar", "baz"} { + for _, s := range []string{"foo4", "bar4", "baz4"} { d, err := wrapData([]byte(s), pid1) if err != nil { t.Error(err) diff --git a/net/net.go b/net/net.go index 67f8254998f..fc341fd7d0e 100644 --- a/net/net.go +++ b/net/net.go @@ -30,7 +30,7 @@ type IpfsNetwork struct { // NewIpfsNetwork is the structure that implements the network interface func NewIpfsNetwork(ctx context.Context, local *peer.Peer, - pmap *mux.ProtocolMap) (*IpfsNetwork, error) { + peers peer.Peerstore, pmap *mux.ProtocolMap) (*IpfsNetwork, error) { ctx, cancel := context.WithCancel(ctx) @@ -47,7 +47,7 @@ func NewIpfsNetwork(ctx context.Context, local *peer.Peer, return nil, err } - in.swarm, err = swarm.NewSwarm(ctx, local) + in.swarm, err = swarm.NewSwarm(ctx, local, peers) if err != nil { cancel() return nil, err diff --git a/net/swarm/conn.go b/net/swarm/conn.go index 03cc92da92c..0713ccf0b8d 100644 --- a/net/swarm/conn.go +++ b/net/swarm/conn.go @@ -76,15 +76,19 @@ func (s *Swarm) connListen(maddr *ma.Multiaddr) error { // Handle getting ID from this peer, handshake, and adding it into the map func (s *Swarm) handleIncomingConn(nconn net.Conn) { - c, err := conn.NewNetConn(nconn) + addr, err := conn.NetConnMultiaddr(nconn) if err != nil { s.errChan <- err return } - //TODO(jbenet) the peer might potentially already be in the global PeerBook. - // maybe use the handshake to populate peer. - c.Peer.AddAddress(c.Addr) + // Construct conn with nil peer for now, because we don't know its ID yet. + // connSetup will figure this out, and pull out / construct the peer. + c, err := conn.NewConn(nil, addr, nconn) + if err != nil { + s.errChan <- err + return + } // Setup the new connection err = s.connSetup(c) @@ -101,7 +105,11 @@ func (s *Swarm) connSetup(c *conn.Conn) error { return errors.New("Tried to start nil connection.") } - u.DOut("Starting connection: %s\n", c.Peer.Key().Pretty()) + if c.Peer != nil { + u.DOut("Starting connection: %s\n", c.Peer.Key().Pretty()) + } else { + u.DOut("Starting connection: [unknown peer]\n") + } if err := s.connSecure(c); err != nil { return fmt.Errorf("Conn securing error: %v", err) @@ -109,6 +117,9 @@ func (s *Swarm) connSetup(c *conn.Conn) error { u.DOut("Secured connection: %s\n", c.Peer.Key().Pretty()) + // add address of connection to Peer. Maybe it should happen in connSecure. + c.Peer.AddAddress(c.Addr) + // add to conns s.connsLock.Lock() if _, ok := s.conns[c.Peer.Key()]; ok { @@ -126,7 +137,7 @@ func (s *Swarm) connSetup(c *conn.Conn) error { // connSecure setups a secure remote connection. func (s *Swarm) connSecure(c *conn.Conn) error { - sp, err := spipe.NewSecurePipe(s.ctx, 10, s.local, c.Peer) + sp, err := spipe.NewSecurePipe(s.ctx, 10, s.local, s.peers) if err != nil { return err } @@ -139,6 +150,13 @@ func (s *Swarm) connSecure(c *conn.Conn) error { return err } + if c.Peer == nil { + c.Peer = sp.RemotePeer() + + } else if c.Peer != sp.RemotePeer() { + panic("peers not being constructed correctly.") + } + c.Secure = sp return nil } diff --git a/net/swarm/swarm.go b/net/swarm/swarm.go index 7ef4ce234c1..df84e5a94ce 100644 --- a/net/swarm/swarm.go +++ b/net/swarm/swarm.go @@ -46,6 +46,9 @@ type Swarm struct { // local is the peer this swarm represents local *peer.Peer + // peers is a collection of peers for swarm to use + peers peer.Peerstore + // Swarm includes a Pipe object. *msg.Pipe @@ -65,11 +68,12 @@ type Swarm struct { } // NewSwarm constructs a Swarm, with a Chan. -func NewSwarm(ctx context.Context, local *peer.Peer) (*Swarm, error) { +func NewSwarm(ctx context.Context, local *peer.Peer, ps peer.Peerstore) (*Swarm, error) { s := &Swarm{ Pipe: msg.NewPipe(10), conns: conn.Map{}, local: local, + peers: ps, errChan: make(chan error, 100), } @@ -112,9 +116,18 @@ func (s *Swarm) Dial(peer *peer.Peer) (*conn.Conn, error) { // check if we already have an open connection first c := s.GetConnection(peer.ID) + if c != nil { + return c, nil + } + + // check if we don't have the peer in Peerstore + err := s.peers.Put(peer) + if err != nil { + return nil, err + } // open connection to peer - c, err := conn.Dial("tcp", peer) + c, err = conn.Dial("tcp", peer) if err != nil { return nil, err } diff --git a/net/swarm/swarm_test.go b/net/swarm/swarm_test.go index 7702e3cb87f..b2747481c9e 100644 --- a/net/swarm/swarm_test.go +++ b/net/swarm/swarm_test.go @@ -71,7 +71,9 @@ func TestSwarm(t *testing.T) { t.Fatal("error setting up peer", err) } - swarm, err := NewSwarm(context.Background(), local) + peerstore := peer.NewPeerstore() + + swarm, err := NewSwarm(context.Background(), local, peerstore) if err != nil { t.Error(err) } diff --git a/peer/peerstore.go b/peer/peerstore.go index 2184d89425e..9c0f28df316 100644 --- a/peer/peerstore.go +++ b/peer/peerstore.go @@ -9,6 +9,10 @@ import ( ds "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/datastore.go" ) +// ErrNotFound signals a peer wasn't found. this is here to avoid having to +// leak the ds abstraction to clients of Peerstore, just for the error. +var ErrNotFound = ds.ErrNotFound + // Peerstore provides a threadsafe collection for peers. type Peerstore interface { Get(ID) (*Peer, error) diff --git a/peer/queue/queue_test.go b/peer/queue/queue_test.go index ff2aafa2a31..8a7d22189fc 100644 --- a/peer/queue/queue_test.go +++ b/peer/queue/queue_test.go @@ -2,6 +2,7 @@ package queue import ( "fmt" + "sync" "testing" "time" @@ -72,18 +73,21 @@ func newPeerTime(t time.Time) *peer.Peer { } func TestSyncQueue(t *testing.T) { - ctx, _ := context.WithTimeout(context.Background(), time.Second*2) + ctx := context.Background() pq := NewXORDistancePQ(u.Key("11140beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a31")) cq := NewChanQueue(ctx, pq) + wg := sync.WaitGroup{} - max := 100000 + max := 10000 consumerN := 10 countsIn := make([]int, consumerN*2) countsOut := make([]int, consumerN) produce := func(p int) { - tick := time.Tick(time.Millisecond) + defer wg.Done() + + tick := time.Tick(time.Microsecond * 100) for i := 0; i < max; i++ { select { case tim := <-tick: @@ -96,10 +100,15 @@ func TestSyncQueue(t *testing.T) { } consume := func(c int) { + defer wg.Done() + for { select { case <-cq.DeqChan: countsOut[c]++ + if countsOut[c] >= max*2 { + return + } case <-ctx.Done(): return } @@ -108,14 +117,13 @@ func TestSyncQueue(t *testing.T) { // make n * 2 producers and n consumers for i := 0; i < consumerN; i++ { + wg.Add(3) go produce(i) go produce(consumerN + i) go consume(i) } - select { - case <-ctx.Done(): - } + wg.Wait() sum := func(ns []int) int { total := 0 @@ -126,6 +134,6 @@ func TestSyncQueue(t *testing.T) { } if sum(countsIn) != sum(countsOut) { - t.Errorf("didnt get all of them out: %d/%d", countsOut, countsIn) + t.Errorf("didnt get all of them out: %d/%d", sum(countsOut), sum(countsIn)) } } diff --git a/routing/dht/dht_test.go b/routing/dht/dht_test.go index 1f41e754aca..1bbc62cdc0f 100644 --- a/routing/dht/dht_test.go +++ b/routing/dht/dht_test.go @@ -31,7 +31,7 @@ func setupDHT(t *testing.T, p *peer.Peer) *IpfsDHT { t.Fatal(err) } - net, err := inet.NewIpfsNetwork(ctx, p, &mux.ProtocolMap{ + net, err := inet.NewIpfsNetwork(ctx, p, peerstore, &mux.ProtocolMap{ mux.ProtocolID_Routing: dhts, }) if err != nil { diff --git a/util/util.go b/util/util.go index 548d777baa0..41f6afede47 100644 --- a/util/util.go +++ b/util/util.go @@ -8,6 +8,7 @@ import ( "path/filepath" "strings" + ds "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/datastore.go" b58 "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-base58" mh "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multihash" ) @@ -26,7 +27,7 @@ var ErrTimeout = errors.New("Error: Call timed out.") var ErrSearchIncomplete = errors.New("Error: Search Incomplete.") // ErrNotFound is returned when a search fails to find anything -var ErrNotFound = errors.New("Error: Not Found.") +var ErrNotFound = ds.ErrNotFound // Key is a string representation of multihash for use with maps. type Key string