From 75c23f229fe3d5cb249573e1f999b48d88ad08a5 Mon Sep 17 00:00:00 2001 From: Dan Sipola Date: Mon, 22 Jan 2018 12:40:22 +0100 Subject: [PATCH] TCP+SNI support arbitrary large Client Hello Use the TLS data to calculate the length of the Client Hello message so that arbitrary large messages are supported. Previoulsy messages larger than 1024 bytes failed due to a fixed size buffer. Note that a Client Hello message must still fit in one TLS record. The code for parsing the TLS data has been moved around a bit to ease unit testing. --- proxy/tcp/sni_proxy.go | 77 +++++++++++++++++++++++++++---- proxy/tcp/sni_proxy_test.go | 58 +++++++++++++++++++++++ proxy/tcp/tls_clienthello.go | 73 ++++------------------------- proxy/tcp/tls_clienthello_test.go | 57 +++++++++++++++++++++++ 4 files changed, 194 insertions(+), 71 deletions(-) create mode 100644 proxy/tcp/sni_proxy_test.go create mode 100644 proxy/tcp/tls_clienthello_test.go diff --git a/proxy/tcp/sni_proxy.go b/proxy/tcp/sni_proxy.go index b75d06705..e1b3f388a 100644 --- a/proxy/tcp/sni_proxy.go +++ b/proxy/tcp/sni_proxy.go @@ -1,6 +1,8 @@ package tcp import ( + "bufio" + "errors" "io" "log" "net" @@ -34,6 +36,48 @@ type SNIProxy struct { Noroute metrics.Counter } +// Create a buffer large enough to hold the client hello message including +// the tls record header and the handshake message header. +// The function requires at least the first 9 bytes of the tls conversation +// in "data". +// nil, error is returned if the data does not follow the +// specification (https://tools.ietf.org/html/rfc5246) or if the client hello +// is fragmented over multiple records. +func createClientHelloBuffer(data []byte) ([]byte, error) { + // TLS record header + // ----------------- + // byte 0: rec type (should be 0x16 == Handshake) + // byte 1-2: version (should be 0x3000 < v < 0x3003) + // byte 3-4: rec len + if len(data) < 9 { + return nil, errors.New("At least 9 bytes required to determine client hello length") + } + + if data[0] != 0x16 { + return nil, errors.New("Not a TLS handshake") + } + + recordLength := int(data[3])<<8 | int(data[4]) + if recordLength <= 0 || recordLength > 16384 { + return nil, errors.New("Invalid TLS record length") + } + + // Handshake record header + // ----------------------- + // byte 5: hs msg type (should be 0x01 == client_hello) + // byte 6-8: hs msg len + if data[5] != 0x01 { + return nil, errors.New("Not a client hello") + } + + handshakeLength := int(data[6])<<16 | int(data[7])<<8 | int(data[8]) + if handshakeLength <= 0 || handshakeLength > recordLength-4 { + return nil, errors.New("Invalid client hello length (fragmentation not implemented)") + } + + return make([]byte, handshakeLength+9), nil //9 for the header bytes +} + func (p *SNIProxy) ServeTCP(in net.Conn) error { defer in.Close() @@ -41,20 +85,37 @@ func (p *SNIProxy) ServeTCP(in net.Conn) error { p.Conn.Inc(1) } - // capture client hello - data := make([]byte, 1024) - n, err := in.Read(data) + tlsReader := bufio.NewReader(in) + data, err := tlsReader.Peek(9) + if err != nil { + log.Print("[DEBUG] tcp+sni: TLS handshake failed (failed to peek data)") + if p.ConnFail != nil { + p.ConnFail.Inc(1) + } + return err + } + + tlsData, err := createClientHelloBuffer(data) + if err != nil { + log.Printf("[DEBUG] tcp+sni: TLS handshake failed (%s)", err) + if p.ConnFail != nil { + p.ConnFail.Inc(1) + } + return err + } + + _, err = io.ReadFull(tlsReader, tlsData) if err != nil { + log.Printf("[DEBUG] tcp+sni: TLS handshake failed (%s)", err) if p.ConnFail != nil { p.ConnFail.Inc(1) } return err } - data = data[:n] - host, ok := readServerName(data) + host, ok := readServerName(tlsData[5:]) if !ok { - log.Print("[DEBUG] tcp+sni: TLS handshake failed") + log.Print("[DEBUG] tcp+sni: TLS handshake failed (unable to parse client hello)") if p.ConnFail != nil { p.ConnFail.Inc(1) } @@ -88,8 +149,8 @@ func (p *SNIProxy) ServeTCP(in net.Conn) error { } defer out.Close() - // copy client hello - n, err = out.Write(data) + // write the data already read from the connection + n, err := out.Write(tlsData) if err != nil { log.Print("[WARN] tcp+sni: copy client hello failed. ", err) if p.ConnFail != nil { diff --git a/proxy/tcp/sni_proxy_test.go b/proxy/tcp/sni_proxy_test.go new file mode 100644 index 000000000..c41726326 --- /dev/null +++ b/proxy/tcp/sni_proxy_test.go @@ -0,0 +1,58 @@ +package tcp + +import ( + "testing" + + "github.com/fabiolb/fabio/assert" +) + +func TestCreateClientHelloBufferNotTLS(t *testing.T) { + assertEqual := assert.Equal(t) + + testCases := [][]byte{ + // not enough data + {0x16, 0x03, 0x01, 0x00, 0x00, 0x01, 0x00, 0x05}, + + // not tls record + {0x15, 0x03, 0x01, 0x01, 0xF4, 0x01, 0x00, 0x01, 0xeb}, + + // too large record + // |---------| + {0x16, 0x03, 0x01, 0x40, 0x01, 0x01, 0x00, 0x01, 0xec}, + + // zero record length + // |----------| + {0x16, 0x03, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0xec}, + + // not client hello + // |----| + {0x16, 0x03, 0x01, 0x01, 0xF4, 0x02, 0x00, 0x01, 0xeb}, + + // bad handshake length + // |----- 0 --------| + {0x16, 0x03, 0x01, 0x00, 0xaa, 0x01, 0x00, 0x00, 0x00}, + + // Fragmentation (handshake larger than record) + // |- 500 ---| |----- 497 ------| + {0x16, 0x03, 0x01, 0x01, 0xF4, 0x01, 0x00, 0x01, 0xf1}, + } + + for i := 0; i < len(testCases); i++ { + _, err := createClientHelloBuffer(testCases[i]) + if err == nil { + t.Logf("Case idx %d did not return an error", i) + } + assertEqual(err != nil, true) + } +} + +func TestCreateClientHelloBufferOk(t *testing.T) { + assertEqual := assert.Equal(t) + // Largest possible client hello message + // |- 16384 -| |----- 16380 ----| + data := []byte{0x16, 0x03, 0x01, 0x40, 0x00, 0x01, 0x00, 0x3f, 0xfc} + buffer, err := createClientHelloBuffer(data) + assertEqual(err, nil) + assertEqual(buffer != nil, true) + assertEqual(len(buffer), 16384+5) // record length + record header +} diff --git a/proxy/tcp/tls_clienthello.go b/proxy/tcp/tls_clienthello.go index 6988c307a..8b535abf5 100644 --- a/proxy/tcp/tls_clienthello.go +++ b/proxy/tcp/tls_clienthello.go @@ -1,73 +1,20 @@ package tcp -// record types -const ( - handshakeRecord = 0x16 - clientHelloType = 0x01 -) - // readServerName returns the server name from a TLS ClientHello message which // has the server_name extension (SNI). ok is set to true if the ClientHello // message was parsed successfully. If the server_name extension was not set -// and empty string is returned as serverName. -func readServerName(data []byte) (serverName string, ok bool) { - if m, ok := readClientHello(data); ok { - return m.serverName, true - } - return "", false -} - -// readClientHello -func readClientHello(data []byte) (m *clientHelloMsg, ok bool) { - if len(data) < 9 { - // println("buf too short") - return nil, false - } - - // TLS record header - // ----------------- - // byte 0: rec type (should be 0x16 == Handshake) - // byte 1-2: version (should be 0x3000 < v < 0x3003) - // byte 3-4: rec len - recType := data[0] - if recType != handshakeRecord { - // println("no handshake ") - return nil, false +// an empty string is returned as serverName. +// clientHelloHandshakeMsg must contain the full client hello handshake +// message including the 4 byte header. +// See: https://www.ietf.org/rfc/rfc5246.txt +func readServerName(clientHelloHandshakeMsg []byte) (serverName string, ok bool) { + m := new(clientHelloMsg) + if !m.unmarshal(clientHelloHandshakeMsg) { + //println("client_hello unmarshal failed") + return "", false } - recLen := int(data[3])<<8 | int(data[4]) - if recLen == 0 || recLen > len(data)-5 { - // println("rec too short") - return nil, false - } - - // Handshake record header - // ----------------------- - // byte 5: hs msg type (should be 0x01 == client_hello) - // byte 6-8: hs msg len - hsType := data[5] - if hsType != clientHelloType { - // println("no client_hello") - return nil, false - } - - hsLen := int(data[6])<<16 | int(data[7])<<8 | int(data[8]) - if hsLen == 0 || hsLen > len(data)-9 { - // println("handshake rec too short") - return nil, false - } - - // byte 9- : client hello msg - // - // m.unmarshal parses the entire handshake message and - // not just the client hello. Therefore, we need to pass - // data from byte 5 instead of byte 9. (see comment below) - m = new(clientHelloMsg) - if !m.unmarshal(data[5:]) { - // println("client_hello unmarshal failed") - return nil, false - } - return m, true + return m.serverName, true } // The code below is a verbatim copy from go1.7/src/crypto/tls/handshake_messages.go diff --git a/proxy/tcp/tls_clienthello_test.go b/proxy/tcp/tls_clienthello_test.go new file mode 100644 index 000000000..9457bfba1 --- /dev/null +++ b/proxy/tcp/tls_clienthello_test.go @@ -0,0 +1,57 @@ +package tcp + +import ( + "encoding/hex" + "testing" + + "github.com/fabiolb/fabio/assert" +) + +func TestReadServerNameBadData(t *testing.T) { + assertEqual := assert.Equal(t) + clientHelloMsg := []byte{0x16, 0x03, 0x01, 0x45, 0x03, 0x01, 0x2, 0x01} + serverName, ok := readServerName(clientHelloMsg) + assertEqual(serverName, "") + assertEqual(ok, false) +} + +func TestReadServerNameNoExtension(t *testing.T) { + assertEqual := assert.Equal(t) + // Client hello from: + // openssl s_client -connect google.com:443 + clientHelloMsg, _ := hex.DecodeString( + "0100013503036dfb09de7b16503dd1bb304dcbe54079913b65abf53de997f73b26c99e" + + "67ba28000098cc14cc13cc15c030c02cc028c024c014c00a00a3009f006b006a00" + + "390038ff8500c400c3008800870081c032c02ec02ac026c00fc005009d003d0035" + + "00c00084c02fc02bc027c023c013c00900a2009e006700400033003200be00bd00" + + "450044c031c02dc029c025c00ec004009c003c002f00ba0041c011c007c00cc002" + + "00050004c012c00800160013c00dc003000a00150012000900ff01000074000b00" + + "0403000102000a003a0038000e000d0019001c000b000c001b00180009000a001a" + + "00160017000800060007001400150004000500120013000100020003000f001000" + + "1100230000000d00260024060106020603efef050105020503040104020403eeee" + + "eded030103020303020102020203") + serverName, ok := readServerName(clientHelloMsg) + assertEqual(serverName, "") + assertEqual(ok, true) +} + +func TestReadServerNameOk(t *testing.T) { + assertEqual := assert.Equal(t) + // Client hello from: + // openssl s_client -connect google.com:443 -servername google.com + clientHelloMsg, _ := hex.DecodeString( + "0100014803032657cacce41598fa82e5b75061050bc31c5affdba106b8e7431852" + + "24af0fa1aa000098cc14cc13cc15c030c02cc028c024c014c00a00a3009f00" + + "6b006a00390038ff8500c400c3008800870081c032c02ec02ac026c00fc005" + + "009d003d003500c00084c02fc02bc027c023c013c00900a2009e0067004000" + + "33003200be00bd00450044c031c02dc029c025c00ec004009c003c002f00ba" + + "0041c011c007c00cc00200050004c012c00800160013c00dc003000a001500" + + "12000900ff010000870000000f000d00000a676f6f676c652e636f6d000b00" + + "0403000102000a003a0038000e000d0019001c000b000c001b00180009000a" + + "001a0016001700080006000700140015000400050012001300010002000300" + + "0f0010001100230000000d00260024060106020603efef0501050205030401" + + "04020403eeeeeded030103020303020102020203") + serverName, ok := readServerName(clientHelloMsg) + assertEqual(serverName, "google.com") + assertEqual(ok, true) +}