diff --git a/ssh/handshake.go b/ssh/handshake.go index e68e058615..93c23d16c8 100644 --- a/ssh/handshake.go +++ b/ssh/handshake.go @@ -65,8 +65,9 @@ type handshakeTransport struct { pendingPackets [][]byte // Used when a key exchange is in progress. // If the read loop wants to schedule a kex, it pings this - // channel, and the write loop will send out a kex message. - requestKex chan struct{} + // channel, and the write loop will send out a kex + // message. The boolean is whether this is the first request or not. + requestKex chan bool // If the other side requests or confirms a kex, its kexInit // packet is sent here for the write loop to find it. @@ -96,11 +97,14 @@ func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion: serverVersion, clientVersion: clientVersion, incoming: make(chan []byte, chanSize), - requestKex: make(chan struct{}, 1), + requestKex: make(chan bool, 1), startKex: make(chan *pendingKex, 1), config: config, } + + // We always start with a mandatory key exchange. + t.requestKex <- true return t } @@ -174,12 +178,6 @@ func (t *handshakeTransport) readPacket() ([]byte, error) { } func (t *handshakeTransport) readLoop() { - // We always start with the mandatory key exchange. We use - // the channel for simplicity, and this works if we can rely - // on the SSH package itself not doing anything else before - // waitSession has completed. - t.requestKeyExchange() - first := true for { p, err := t.readOnePacket(first) @@ -227,14 +225,15 @@ func (t *handshakeTransport) recordWriteError(err error) { func (t *handshakeTransport) requestKeyExchange() { select { - case t.requestKex <- struct{}{}: + case t.requestKex <- false: default: // something already requested a kex, so do nothing. } - } func (t *handshakeTransport) kexLoop() { + firstSent := false + write: for t.getWriteError() == nil { var request *pendingKex @@ -247,7 +246,18 @@ write: if !ok { break write } - case <-t.requestKex: + case requestFirst := <-t.requestKex: + // For the first key exchange, both + // sides will initiate a key exchange, + // and both channels will fire. To + // avoid doing two key exchanges in a + // row, ignore our own request for an + // initial kex if we have already sent + // it out. + if firstSent && requestFirst { + + continue + } } if !sent { @@ -255,6 +265,7 @@ write: t.recordWriteError(err) break } + firstSent = true sent = true } }