Skip to content

Commit

Permalink
TCP+SNI support arbitrary large Client Hello
Browse files Browse the repository at this point in the history
Use the TLS data to calculate the length of the Client Hello
message so that arbitrary large messages are supported. Previoulsy
messages larger than 1024 bytes failed due to a fixed size buffer.
Note that a Client Hello message must still fit in one TLS record.

The code for parsing the TLS data has been moved around
a bit to ease unit testing.
  • Loading branch information
DanSipola committed Jan 23, 2018
1 parent 2b7594a commit 75c23f2
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 71 deletions.
77 changes: 69 additions & 8 deletions proxy/tcp/sni_proxy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package tcp

import (
"bufio"
"errors"
"io"
"log"
"net"
Expand Down Expand Up @@ -34,27 +36,86 @@ type SNIProxy struct {
Noroute metrics.Counter
}

// Create a buffer large enough to hold the client hello message including
// the tls record header and the handshake message header.
// The function requires at least the first 9 bytes of the tls conversation
// in "data".
// nil, error is returned if the data does not follow the
// specification (https://tools.ietf.org/html/rfc5246) or if the client hello
// is fragmented over multiple records.
func createClientHelloBuffer(data []byte) ([]byte, error) {
// TLS record header
// -----------------
// byte 0: rec type (should be 0x16 == Handshake)
// byte 1-2: version (should be 0x3000 < v < 0x3003)
// byte 3-4: rec len
if len(data) < 9 {
return nil, errors.New("At least 9 bytes required to determine client hello length")
}

if data[0] != 0x16 {
return nil, errors.New("Not a TLS handshake")
}

recordLength := int(data[3])<<8 | int(data[4])
if recordLength <= 0 || recordLength > 16384 {
return nil, errors.New("Invalid TLS record length")
}

// Handshake record header
// -----------------------
// byte 5: hs msg type (should be 0x01 == client_hello)
// byte 6-8: hs msg len
if data[5] != 0x01 {
return nil, errors.New("Not a client hello")
}

handshakeLength := int(data[6])<<16 | int(data[7])<<8 | int(data[8])
if handshakeLength <= 0 || handshakeLength > recordLength-4 {
return nil, errors.New("Invalid client hello length (fragmentation not implemented)")
}

return make([]byte, handshakeLength+9), nil //9 for the header bytes
}

func (p *SNIProxy) ServeTCP(in net.Conn) error {
defer in.Close()

if p.Conn != nil {
p.Conn.Inc(1)
}

// capture client hello
data := make([]byte, 1024)
n, err := in.Read(data)
tlsReader := bufio.NewReader(in)
data, err := tlsReader.Peek(9)
if err != nil {
log.Print("[DEBUG] tcp+sni: TLS handshake failed (failed to peek data)")
if p.ConnFail != nil {
p.ConnFail.Inc(1)
}
return err
}

tlsData, err := createClientHelloBuffer(data)
if err != nil {
log.Printf("[DEBUG] tcp+sni: TLS handshake failed (%s)", err)
if p.ConnFail != nil {
p.ConnFail.Inc(1)
}
return err
}

_, err = io.ReadFull(tlsReader, tlsData)
if err != nil {
log.Printf("[DEBUG] tcp+sni: TLS handshake failed (%s)", err)
if p.ConnFail != nil {
p.ConnFail.Inc(1)
}
return err
}
data = data[:n]

host, ok := readServerName(data)
host, ok := readServerName(tlsData[5:])
if !ok {
log.Print("[DEBUG] tcp+sni: TLS handshake failed")
log.Print("[DEBUG] tcp+sni: TLS handshake failed (unable to parse client hello)")
if p.ConnFail != nil {
p.ConnFail.Inc(1)
}
Expand Down Expand Up @@ -88,8 +149,8 @@ func (p *SNIProxy) ServeTCP(in net.Conn) error {
}
defer out.Close()

// copy client hello
n, err = out.Write(data)
// write the data already read from the connection
n, err := out.Write(tlsData)
if err != nil {
log.Print("[WARN] tcp+sni: copy client hello failed. ", err)
if p.ConnFail != nil {
Expand Down
58 changes: 58 additions & 0 deletions proxy/tcp/sni_proxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package tcp

import (
"testing"

"github.com/fabiolb/fabio/assert"
)

func TestCreateClientHelloBufferNotTLS(t *testing.T) {
assertEqual := assert.Equal(t)

testCases := [][]byte{
// not enough data
{0x16, 0x03, 0x01, 0x00, 0x00, 0x01, 0x00, 0x05},

// not tls record
{0x15, 0x03, 0x01, 0x01, 0xF4, 0x01, 0x00, 0x01, 0xeb},

// too large record
// |---------|
{0x16, 0x03, 0x01, 0x40, 0x01, 0x01, 0x00, 0x01, 0xec},

// zero record length
// |----------|
{0x16, 0x03, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0xec},

// not client hello
// |----|
{0x16, 0x03, 0x01, 0x01, 0xF4, 0x02, 0x00, 0x01, 0xeb},

// bad handshake length
// |----- 0 --------|
{0x16, 0x03, 0x01, 0x00, 0xaa, 0x01, 0x00, 0x00, 0x00},

// Fragmentation (handshake larger than record)
// |- 500 ---| |----- 497 ------|
{0x16, 0x03, 0x01, 0x01, 0xF4, 0x01, 0x00, 0x01, 0xf1},
}

for i := 0; i < len(testCases); i++ {
_, err := createClientHelloBuffer(testCases[i])
if err == nil {
t.Logf("Case idx %d did not return an error", i)
}
assertEqual(err != nil, true)
}
}

func TestCreateClientHelloBufferOk(t *testing.T) {
assertEqual := assert.Equal(t)
// Largest possible client hello message
// |- 16384 -| |----- 16380 ----|
data := []byte{0x16, 0x03, 0x01, 0x40, 0x00, 0x01, 0x00, 0x3f, 0xfc}
buffer, err := createClientHelloBuffer(data)
assertEqual(err, nil)
assertEqual(buffer != nil, true)
assertEqual(len(buffer), 16384+5) // record length + record header
}
73 changes: 10 additions & 63 deletions proxy/tcp/tls_clienthello.go
Original file line number Diff line number Diff line change
@@ -1,73 +1,20 @@
package tcp

// record types
const (
handshakeRecord = 0x16
clientHelloType = 0x01
)

// readServerName returns the server name from a TLS ClientHello message which
// has the server_name extension (SNI). ok is set to true if the ClientHello
// message was parsed successfully. If the server_name extension was not set
// and empty string is returned as serverName.
func readServerName(data []byte) (serverName string, ok bool) {
if m, ok := readClientHello(data); ok {
return m.serverName, true
}
return "", false
}

// readClientHello
func readClientHello(data []byte) (m *clientHelloMsg, ok bool) {
if len(data) < 9 {
// println("buf too short")
return nil, false
}

// TLS record header
// -----------------
// byte 0: rec type (should be 0x16 == Handshake)
// byte 1-2: version (should be 0x3000 < v < 0x3003)
// byte 3-4: rec len
recType := data[0]
if recType != handshakeRecord {
// println("no handshake ")
return nil, false
// an empty string is returned as serverName.
// clientHelloHandshakeMsg must contain the full client hello handshake
// message including the 4 byte header.
// See: https://www.ietf.org/rfc/rfc5246.txt
func readServerName(clientHelloHandshakeMsg []byte) (serverName string, ok bool) {
m := new(clientHelloMsg)
if !m.unmarshal(clientHelloHandshakeMsg) {
//println("client_hello unmarshal failed")
return "", false
}

recLen := int(data[3])<<8 | int(data[4])
if recLen == 0 || recLen > len(data)-5 {
// println("rec too short")
return nil, false
}

// Handshake record header
// -----------------------
// byte 5: hs msg type (should be 0x01 == client_hello)
// byte 6-8: hs msg len
hsType := data[5]
if hsType != clientHelloType {
// println("no client_hello")
return nil, false
}

hsLen := int(data[6])<<16 | int(data[7])<<8 | int(data[8])
if hsLen == 0 || hsLen > len(data)-9 {
// println("handshake rec too short")
return nil, false
}

// byte 9- : client hello msg
//
// m.unmarshal parses the entire handshake message and
// not just the client hello. Therefore, we need to pass
// data from byte 5 instead of byte 9. (see comment below)
m = new(clientHelloMsg)
if !m.unmarshal(data[5:]) {
// println("client_hello unmarshal failed")
return nil, false
}
return m, true
return m.serverName, true
}

// The code below is a verbatim copy from go1.7/src/crypto/tls/handshake_messages.go
Expand Down
57 changes: 57 additions & 0 deletions proxy/tcp/tls_clienthello_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package tcp

import (
"encoding/hex"
"testing"

"github.com/fabiolb/fabio/assert"
)

func TestReadServerNameBadData(t *testing.T) {
assertEqual := assert.Equal(t)
clientHelloMsg := []byte{0x16, 0x03, 0x01, 0x45, 0x03, 0x01, 0x2, 0x01}
serverName, ok := readServerName(clientHelloMsg)
assertEqual(serverName, "")
assertEqual(ok, false)
}

func TestReadServerNameNoExtension(t *testing.T) {
assertEqual := assert.Equal(t)
// Client hello from:
// openssl s_client -connect google.com:443
clientHelloMsg, _ := hex.DecodeString(
"0100013503036dfb09de7b16503dd1bb304dcbe54079913b65abf53de997f73b26c99e" +
"67ba28000098cc14cc13cc15c030c02cc028c024c014c00a00a3009f006b006a00" +
"390038ff8500c400c3008800870081c032c02ec02ac026c00fc005009d003d0035" +
"00c00084c02fc02bc027c023c013c00900a2009e006700400033003200be00bd00" +
"450044c031c02dc029c025c00ec004009c003c002f00ba0041c011c007c00cc002" +
"00050004c012c00800160013c00dc003000a00150012000900ff01000074000b00" +
"0403000102000a003a0038000e000d0019001c000b000c001b00180009000a001a" +
"00160017000800060007001400150004000500120013000100020003000f001000" +
"1100230000000d00260024060106020603efef050105020503040104020403eeee" +
"eded030103020303020102020203")
serverName, ok := readServerName(clientHelloMsg)
assertEqual(serverName, "")
assertEqual(ok, true)
}

func TestReadServerNameOk(t *testing.T) {
assertEqual := assert.Equal(t)
// Client hello from:
// openssl s_client -connect google.com:443 -servername google.com
clientHelloMsg, _ := hex.DecodeString(
"0100014803032657cacce41598fa82e5b75061050bc31c5affdba106b8e7431852" +
"24af0fa1aa000098cc14cc13cc15c030c02cc028c024c014c00a00a3009f00" +
"6b006a00390038ff8500c400c3008800870081c032c02ec02ac026c00fc005" +
"009d003d003500c00084c02fc02bc027c023c013c00900a2009e0067004000" +
"33003200be00bd00450044c031c02dc029c025c00ec004009c003c002f00ba" +
"0041c011c007c00cc00200050004c012c00800160013c00dc003000a001500" +
"12000900ff010000870000000f000d00000a676f6f676c652e636f6d000b00" +
"0403000102000a003a0038000e000d0019001c000b000c001b00180009000a" +
"001a0016001700080006000700140015000400050012001300010002000300" +
"0f0010001100230000000d00260024060106020603efef0501050205030401" +
"04020403eeeeeded030103020303020102020203")
serverName, ok := readServerName(clientHelloMsg)
assertEqual(serverName, "google.com")
assertEqual(ok, true)
}

0 comments on commit 75c23f2

Please sign in to comment.