diff --git a/proxy/tcp/sni_proxy.go b/proxy/tcp/sni_proxy.go index b75d06705..e1b3f388a 100644 --- a/proxy/tcp/sni_proxy.go +++ b/proxy/tcp/sni_proxy.go @@ -1,6 +1,8 @@ package tcp import ( + "bufio" + "errors" "io" "log" "net" @@ -34,6 +36,48 @@ type SNIProxy struct { Noroute metrics.Counter } +// Create a buffer large enough to hold the client hello message including +// the tls record header and the handshake message header. +// The function requires at least the first 9 bytes of the tls conversation +// in "data". +// nil, error is returned if the data does not follow the +// specification (https://tools.ietf.org/html/rfc5246) or if the client hello +// is fragmented over multiple records. +func createClientHelloBuffer(data []byte) ([]byte, error) { + // TLS record header + // ----------------- + // byte 0: rec type (should be 0x16 == Handshake) + // byte 1-2: version (should be 0x3000 < v < 0x3003) + // byte 3-4: rec len + if len(data) < 9 { + return nil, errors.New("At least 9 bytes required to determine client hello length") + } + + if data[0] != 0x16 { + return nil, errors.New("Not a TLS handshake") + } + + recordLength := int(data[3])<<8 | int(data[4]) + if recordLength <= 0 || recordLength > 16384 { + return nil, errors.New("Invalid TLS record length") + } + + // Handshake record header + // ----------------------- + // byte 5: hs msg type (should be 0x01 == client_hello) + // byte 6-8: hs msg len + if data[5] != 0x01 { + return nil, errors.New("Not a client hello") + } + + handshakeLength := int(data[6])<<16 | int(data[7])<<8 | int(data[8]) + if handshakeLength <= 0 || handshakeLength > recordLength-4 { + return nil, errors.New("Invalid client hello length (fragmentation not implemented)") + } + + return make([]byte, handshakeLength+9), nil //9 for the header bytes +} + func (p *SNIProxy) ServeTCP(in net.Conn) error { defer in.Close() @@ -41,20 +85,37 @@ func (p *SNIProxy) ServeTCP(in net.Conn) error { p.Conn.Inc(1) } - // capture client hello - data := make([]byte, 1024) - n, err := in.Read(data) + tlsReader := bufio.NewReader(in) + data, err := tlsReader.Peek(9) + if err != nil { + log.Print("[DEBUG] tcp+sni: TLS handshake failed (failed to peek data)") + if p.ConnFail != nil { + p.ConnFail.Inc(1) + } + return err + } + + tlsData, err := createClientHelloBuffer(data) + if err != nil { + log.Printf("[DEBUG] tcp+sni: TLS handshake failed (%s)", err) + if p.ConnFail != nil { + p.ConnFail.Inc(1) + } + return err + } + + _, err = io.ReadFull(tlsReader, tlsData) if err != nil { + log.Printf("[DEBUG] tcp+sni: TLS handshake failed (%s)", err) if p.ConnFail != nil { p.ConnFail.Inc(1) } return err } - data = data[:n] - host, ok := readServerName(data) + host, ok := readServerName(tlsData[5:]) if !ok { - log.Print("[DEBUG] tcp+sni: TLS handshake failed") + log.Print("[DEBUG] tcp+sni: TLS handshake failed (unable to parse client hello)") if p.ConnFail != nil { p.ConnFail.Inc(1) } @@ -88,8 +149,8 @@ func (p *SNIProxy) ServeTCP(in net.Conn) error { } defer out.Close() - // copy client hello - n, err = out.Write(data) + // write the data already read from the connection + n, err := out.Write(tlsData) if err != nil { log.Print("[WARN] tcp+sni: copy client hello failed. ", err) if p.ConnFail != nil { diff --git a/proxy/tcp/sni_proxy_test.go b/proxy/tcp/sni_proxy_test.go new file mode 100644 index 000000000..c41726326 --- /dev/null +++ b/proxy/tcp/sni_proxy_test.go @@ -0,0 +1,58 @@ +package tcp + +import ( + "testing" + + "github.com/fabiolb/fabio/assert" +) + +func TestCreateClientHelloBufferNotTLS(t *testing.T) { + assertEqual := assert.Equal(t) + + testCases := [][]byte{ + // not enough data + {0x16, 0x03, 0x01, 0x00, 0x00, 0x01, 0x00, 0x05}, + + // not tls record + {0x15, 0x03, 0x01, 0x01, 0xF4, 0x01, 0x00, 0x01, 0xeb}, + + // too large record + // |---------| + {0x16, 0x03, 0x01, 0x40, 0x01, 0x01, 0x00, 0x01, 0xec}, + + // zero record length + // |----------| + {0x16, 0x03, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0xec}, + + // not client hello + // |----| + {0x16, 0x03, 0x01, 0x01, 0xF4, 0x02, 0x00, 0x01, 0xeb}, + + // bad handshake length + // |----- 0 --------| + {0x16, 0x03, 0x01, 0x00, 0xaa, 0x01, 0x00, 0x00, 0x00}, + + // Fragmentation (handshake larger than record) + // |- 500 ---| |----- 497 ------| + {0x16, 0x03, 0x01, 0x01, 0xF4, 0x01, 0x00, 0x01, 0xf1}, + } + + for i := 0; i < len(testCases); i++ { + _, err := createClientHelloBuffer(testCases[i]) + if err == nil { + t.Logf("Case idx %d did not return an error", i) + } + assertEqual(err != nil, true) + } +} + +func TestCreateClientHelloBufferOk(t *testing.T) { + assertEqual := assert.Equal(t) + // Largest possible client hello message + // |- 16384 -| |----- 16380 ----| + data := []byte{0x16, 0x03, 0x01, 0x40, 0x00, 0x01, 0x00, 0x3f, 0xfc} + buffer, err := createClientHelloBuffer(data) + assertEqual(err, nil) + assertEqual(buffer != nil, true) + assertEqual(len(buffer), 16384+5) // record length + record header +} diff --git a/proxy/tcp/tls_clienthello.go b/proxy/tcp/tls_clienthello.go index 6988c307a..8b535abf5 100644 --- a/proxy/tcp/tls_clienthello.go +++ b/proxy/tcp/tls_clienthello.go @@ -1,73 +1,20 @@ package tcp -// record types -const ( - handshakeRecord = 0x16 - clientHelloType = 0x01 -) - // readServerName returns the server name from a TLS ClientHello message which // has the server_name extension (SNI). ok is set to true if the ClientHello // message was parsed successfully. If the server_name extension was not set -// and empty string is returned as serverName. -func readServerName(data []byte) (serverName string, ok bool) { - if m, ok := readClientHello(data); ok { - return m.serverName, true - } - return "", false -} - -// readClientHello -func readClientHello(data []byte) (m *clientHelloMsg, ok bool) { - if len(data) < 9 { - // println("buf too short") - return nil, false - } - - // TLS record header - // ----------------- - // byte 0: rec type (should be 0x16 == Handshake) - // byte 1-2: version (should be 0x3000 < v < 0x3003) - // byte 3-4: rec len - recType := data[0] - if recType != handshakeRecord { - // println("no handshake ") - return nil, false +// an empty string is returned as serverName. +// clientHelloHandshakeMsg must contain the full client hello handshake +// message including the 4 byte header. +// See: https://www.ietf.org/rfc/rfc5246.txt +func readServerName(clientHelloHandshakeMsg []byte) (serverName string, ok bool) { + m := new(clientHelloMsg) + if !m.unmarshal(clientHelloHandshakeMsg) { + //println("client_hello unmarshal failed") + return "", false } - recLen := int(data[3])<<8 | int(data[4]) - if recLen == 0 || recLen > len(data)-5 { - // println("rec too short") - return nil, false - } - - // Handshake record header - // ----------------------- - // byte 5: hs msg type (should be 0x01 == client_hello) - // byte 6-8: hs msg len - hsType := data[5] - if hsType != clientHelloType { - // println("no client_hello") - return nil, false - } - - hsLen := int(data[6])<<16 | int(data[7])<<8 | int(data[8]) - if hsLen == 0 || hsLen > len(data)-9 { - // println("handshake rec too short") - return nil, false - } - - // byte 9- : client hello msg - // - // m.unmarshal parses the entire handshake message and - // not just the client hello. Therefore, we need to pass - // data from byte 5 instead of byte 9. (see comment below) - m = new(clientHelloMsg) - if !m.unmarshal(data[5:]) { - // println("client_hello unmarshal failed") - return nil, false - } - return m, true + return m.serverName, true } // The code below is a verbatim copy from go1.7/src/crypto/tls/handshake_messages.go diff --git a/proxy/tcp/tls_clienthello_test.go b/proxy/tcp/tls_clienthello_test.go new file mode 100644 index 000000000..9457bfba1 --- /dev/null +++ b/proxy/tcp/tls_clienthello_test.go @@ -0,0 +1,57 @@ +package tcp + +import ( + "encoding/hex" + "testing" + + "github.com/fabiolb/fabio/assert" +) + +func TestReadServerNameBadData(t *testing.T) { + assertEqual := assert.Equal(t) + clientHelloMsg := []byte{0x16, 0x03, 0x01, 0x45, 0x03, 0x01, 0x2, 0x01} + serverName, ok := readServerName(clientHelloMsg) + assertEqual(serverName, "") + assertEqual(ok, false) +} + +func TestReadServerNameNoExtension(t *testing.T) { + assertEqual := assert.Equal(t) + // Client hello from: + // openssl s_client -connect google.com:443 + clientHelloMsg, _ := hex.DecodeString( + "0100013503036dfb09de7b16503dd1bb304dcbe54079913b65abf53de997f73b26c99e" + + "67ba28000098cc14cc13cc15c030c02cc028c024c014c00a00a3009f006b006a00" + + "390038ff8500c400c3008800870081c032c02ec02ac026c00fc005009d003d0035" + + "00c00084c02fc02bc027c023c013c00900a2009e006700400033003200be00bd00" + + "450044c031c02dc029c025c00ec004009c003c002f00ba0041c011c007c00cc002" + + "00050004c012c00800160013c00dc003000a00150012000900ff01000074000b00" + + "0403000102000a003a0038000e000d0019001c000b000c001b00180009000a001a" + + "00160017000800060007001400150004000500120013000100020003000f001000" + + "1100230000000d00260024060106020603efef050105020503040104020403eeee" + + "eded030103020303020102020203") + serverName, ok := readServerName(clientHelloMsg) + assertEqual(serverName, "") + assertEqual(ok, true) +} + +func TestReadServerNameOk(t *testing.T) { + assertEqual := assert.Equal(t) + // Client hello from: + // openssl s_client -connect google.com:443 -servername google.com + clientHelloMsg, _ := hex.DecodeString( + "0100014803032657cacce41598fa82e5b75061050bc31c5affdba106b8e7431852" + + "24af0fa1aa000098cc14cc13cc15c030c02cc028c024c014c00a00a3009f00" + + "6b006a00390038ff8500c400c3008800870081c032c02ec02ac026c00fc005" + + "009d003d003500c00084c02fc02bc027c023c013c00900a2009e0067004000" + + "33003200be00bd00450044c031c02dc029c025c00ec004009c003c002f00ba" + + "0041c011c007c00cc00200050004c012c00800160013c00dc003000a001500" + + "12000900ff010000870000000f000d00000a676f6f676c652e636f6d000b00" + + "0403000102000a003a0038000e000d0019001c000b000c001b00180009000a" + + "001a0016001700080006000700140015000400050012001300010002000300" + + "0f0010001100230000000d00260024060106020603efef0501050205030401" + + "04020403eeeeeded030103020303020102020203") + serverName, ok := readServerName(clientHelloMsg) + assertEqual(serverName, "google.com") + assertEqual(ok, true) +}