diff --git a/proxy/tcp/sni_proxy.go b/proxy/tcp/sni_proxy.go index 37ec6b7e6..d9920876c 100644 --- a/proxy/tcp/sni_proxy.go +++ b/proxy/tcp/sni_proxy.go @@ -1,6 +1,7 @@ package tcp import ( + "bufio" "io" "log" "net" @@ -41,20 +42,40 @@ func (p *SNIProxy) ServeTCP(in net.Conn) error { p.Conn.Inc(1) } - // capture client hello - data := make([]byte, 1024) - n, err := in.Read(data) + tlsReader := bufio.NewReader(in) + tlsHeaders, err := tlsReader.Peek(9) if err != nil { + log.Print("[DEBUG] tcp+sni: TLS handshake failed (failed to peek data)") if p.ConnFail != nil { p.ConnFail.Inc(1) } return err } - data = data[:n] - host, ok := readServerName(data) + bufferSize, err := clientHelloBufferSize(tlsHeaders) + if err != nil { + log.Printf("[DEBUG] tcp+sni: TLS handshake failed (%s)", err) + if p.ConnFail != nil { + p.ConnFail.Inc(1) + } + return err + } + + data := make([]byte, bufferSize) + _, err = io.ReadFull(tlsReader, data) + if err != nil { + log.Printf("[DEBUG] tcp+sni: TLS handshake failed (%s)", err) + if p.ConnFail != nil { + p.ConnFail.Inc(1) + } + return err + } + + // readServerName wants only the handshake message so ignore the first + // 5 bytes which is the TLS record header + host, ok := readServerName(data[5:]) if !ok { - log.Print("[DEBUG] tcp+sni: TLS handshake failed") + log.Print("[DEBUG] tcp+sni: TLS handshake failed (unable to parse client hello)") if p.ConnFail != nil { p.ConnFail.Inc(1) } @@ -92,8 +113,8 @@ func (p *SNIProxy) ServeTCP(in net.Conn) error { } defer out.Close() - // copy client hello - n, err = out.Write(data) + // write the data already read from the connection + n, err := out.Write(data) if err != nil { log.Print("[WARN] tcp+sni: copy client hello failed. ", err) if p.ConnFail != nil { diff --git a/proxy/tcp/tls_clienthello.go b/proxy/tcp/tls_clienthello.go index 6988c307a..69f9fb468 100644 --- a/proxy/tcp/tls_clienthello.go +++ b/proxy/tcp/tls_clienthello.go @@ -1,73 +1,65 @@ package tcp -// record types -const ( - handshakeRecord = 0x16 - clientHelloType = 0x01 -) - -// readServerName returns the server name from a TLS ClientHello message which -// has the server_name extension (SNI). ok is set to true if the ClientHello -// message was parsed successfully. If the server_name extension was not set -// and empty string is returned as serverName. -func readServerName(data []byte) (serverName string, ok bool) { - if m, ok := readClientHello(data); ok { - return m.serverName, true - } - return "", false -} - -// readClientHello -func readClientHello(data []byte) (m *clientHelloMsg, ok bool) { - if len(data) < 9 { - // println("buf too short") - return nil, false - } +import "errors" +// Determines the required size of a buffer large enough to hold +// a client hello message including the tls record header and the +// handshake message header. +// The function requires at least the first 9 bytes of the tls conversation +// in "data". +// An error is returned if the data does not follow the +// specification (https://tools.ietf.org/html/rfc5246) or if the client hello +// is fragmented over multiple records. +func clientHelloBufferSize(data []byte) (int, error) { // TLS record header // ----------------- // byte 0: rec type (should be 0x16 == Handshake) // byte 1-2: version (should be 0x3000 < v < 0x3003) // byte 3-4: rec len - recType := data[0] - if recType != handshakeRecord { - // println("no handshake ") - return nil, false + if len(data) < 9 { + return 0, errors.New("At least 9 bytes required to determine client hello length") } - recLen := int(data[3])<<8 | int(data[4]) - if recLen == 0 || recLen > len(data)-5 { - // println("rec too short") - return nil, false + if data[0] != 0x16 { + return 0, errors.New("Not a TLS handshake") + } + + recordLength := int(data[3])<<8 | int(data[4]) + if recordLength <= 0 || recordLength > 16384 { + return 0, errors.New("Invalid TLS record length") } // Handshake record header // ----------------------- // byte 5: hs msg type (should be 0x01 == client_hello) // byte 6-8: hs msg len - hsType := data[5] - if hsType != clientHelloType { - // println("no client_hello") - return nil, false + if data[5] != 0x01 { + return 0, errors.New("Not a client hello") } - hsLen := int(data[6])<<16 | int(data[7])<<8 | int(data[8]) - if hsLen == 0 || hsLen > len(data)-9 { - // println("handshake rec too short") - return nil, false + handshakeLength := int(data[6])<<16 | int(data[7])<<8 | int(data[8]) + if handshakeLength <= 0 || handshakeLength > recordLength-4 { + return 0, errors.New("Invalid client hello length (fragmentation not implemented)") } - // byte 9- : client hello msg - // - // m.unmarshal parses the entire handshake message and - // not just the client hello. Therefore, we need to pass - // data from byte 5 instead of byte 9. (see comment below) - m = new(clientHelloMsg) - if !m.unmarshal(data[5:]) { - // println("client_hello unmarshal failed") - return nil, false + return handshakeLength + 9, nil //9 for the header bytes +} + +// readServerName returns the server name from a TLS ClientHello message which +// has the server_name extension (SNI). ok is set to true if the ClientHello +// message was parsed successfully. If the server_name extension was not set +// an empty string is returned as serverName. +// clientHelloHandshakeMsg must contain the full client hello handshake +// message including the 4 byte header. +// See: https://www.ietf.org/rfc/rfc5246.txt +func readServerName(clientHelloHandshakeMsg []byte) (serverName string, ok bool) { + m := new(clientHelloMsg) + if !m.unmarshal(clientHelloHandshakeMsg) { + //println("client_hello unmarshal failed") + return "", false } - return m, true + + return m.serverName, true } // The code below is a verbatim copy from go1.7/src/crypto/tls/handshake_messages.go diff --git a/proxy/tcp/tls_clienthello_test.go b/proxy/tcp/tls_clienthello_test.go new file mode 100644 index 000000000..9c947614d --- /dev/null +++ b/proxy/tcp/tls_clienthello_test.go @@ -0,0 +1,162 @@ +package tcp + +import ( + "encoding/hex" + "testing" +) + +func TestClientHelloBufferSize(t *testing.T) { + tests := []struct { + name string + data []byte + size int + fail bool + }{ + { + name: "valid data", + // Largest possible client hello message + // |- 16384 -| |----- 16380 ----| + data: []byte{0x16, 0x03, 0x01, 0x40, 0x00, 0x01, 0x00, 0x3f, 0xfc}, + size: 16384 + 5, // max record length + record header + fail: false, + }, + { + name: "not enough data", + data: []byte{0x16, 0x03, 0x01, 0x40, 0x00, 0x01, 0x00, 0x3f}, + size: 0, + fail: true, + }, + { + name: "not a TLS record", + data: []byte{0x15, 0x03, 0x01, 0x01, 0xF4, 0x01, 0x00, 0x01, 0xeb}, + size: 0, + fail: true, + }, + + { + name: "TLS record too large", + // | max + 1 | + data: []byte{0x16, 0x03, 0x01, 0x40, 0x01, 0x01, 0x00, 0x3f, 0xfc}, + size: 0, + fail: true, + }, + + { + name: "TLS record length zero", + // |----------| + data: []byte{0x16, 0x03, 0x01, 0x00, 0x00, 0x01, 0x00, 0x3f, 0xfc}, + size: 0, + fail: true, + }, + + { + name: "Not a client hello", + // |----| + data: []byte{0x16, 0x03, 0x01, 0x40, 0x00, 0x02, 0x00, 0x3f, 0xfc}, + size: 0, + fail: true, + }, + + { + name: "Invalid handshake message record length", + // |----- 0 --------| + data: []byte{0x16, 0x03, 0x01, 0x40, 0x00, 0x01, 0x00, 0x00, 0x00}, + size: 0, + fail: true, + }, + + { + name: "Fragmentation (handshake message larger than record)", + // |- 500 ---| |----- 497 ------| + data: []byte{0x16, 0x03, 0x01, 0x01, 0xF4, 0x01, 0x00, 0x01, 0xf1}, + size: 0, + fail: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := clientHelloBufferSize(tt.data) + + if tt.fail && err == nil { + t.Fatal("expected error, got nil") + } else if !tt.fail && err != nil { + t.Fatalf("expected error to be nil, got %s", err) + } + + if want := tt.size; got != want { + t.Fatalf("want size %d, got %d", want, got) + } + }) + } +} + +func TestReadServerName(t *testing.T) { + tests := []struct { + name string + servername string + ok bool + data string //Hex string, decoded by test + }{ + { + // Client hello from: + // openssl s_client -connect google.com:443 -servername google.com + name: "valid client hello with server name", + servername: "google.com", + ok: true, + data: "0100014803032657cacce41598fa82e5b75061050bc31c5affdba106b8e7431852" + + "24af0fa1aa000098cc14cc13cc15c030c02cc028c024c014c00a00a3009f00" + + "6b006a00390038ff8500c400c3008800870081c032c02ec02ac026c00fc005" + + "009d003d003500c00084c02fc02bc027c023c013c00900a2009e0067004000" + + "33003200be00bd00450044c031c02dc029c025c00ec004009c003c002f00ba" + + "0041c011c007c00cc00200050004c012c00800160013c00dc003000a001500" + + "12000900ff010000870000000f000d00000a676f6f676c652e636f6d000b00" + + "0403000102000a003a0038000e000d0019001c000b000c001b00180009000a" + + "001a0016001700080006000700140015000400050012001300010002000300" + + "0f0010001100230000000d00260024060106020603efef0501050205030401" + + "04020403eeeeeded030103020303020102020203", + }, + { + // Client hello from: + // openssl s_client -connect google.com:443 + name: "valid client hello but no server name extension", + servername: "", + ok: true, + data: "0100013503036dfb09de7b16503dd1bb304dcbe54079913b65abf53de997f73b26c99e" + + "67ba28000098cc14cc13cc15c030c02cc028c024c014c00a00a3009f006b006a00" + + "390038ff8500c400c3008800870081c032c02ec02ac026c00fc005009d003d0035" + + "00c00084c02fc02bc027c023c013c00900a2009e006700400033003200be00bd00" + + "450044c031c02dc029c025c00ec004009c003c002f00ba0041c011c007c00cc002" + + "00050004c012c00800160013c00dc003000a00150012000900ff01000074000b00" + + "0403000102000a003a0038000e000d0019001c000b000c001b00180009000a001a" + + "00160017000800060007001400150004000500120013000100020003000f001000" + + "1100230000000d00260024060106020603efef050105020503040104020403eeee" + + "eded030103020303020102020203", + }, + { + name: "invalid client hello", + servername: "", + ok: false, + data: "0100014c5768656e2070656f706c652073617920746f206d653a20776f756c6420796f" + + "75207261746865722062652074686f75676874206f6620617320612066756e6e79" + + "206d616e206f72206120677265617420626f73733f204d7920616e737765722773" + + "20616c77617973207468652073616d652c20746f206d652c207468657927726520" + + "6e6f74206d757475616c6c79206578636c75736976652e2d204461766964204272" + + "656e74", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clientHelloMsg, _ := hex.DecodeString(tt.data) + servername, ok := readServerName(clientHelloMsg) + if got, want := servername, tt.servername; got != want { + t.Fatalf("%s: got servername \"%s\" want \"%s\"", tt.name, got, want) + } + + if got, want := ok, tt.ok; got != want { + t.Fatalf("%s: got ok %t want %t", tt.name, got, want) + } + }) + } +}