Skip to content

Commit

Permalink
Merge pull request #423 from DanSipola/master
Browse files Browse the repository at this point in the history
TCP+SNI support arbitrary large Client Hello
  • Loading branch information
magiconair authored Feb 18, 2018
2 parents 19a51e8 + 9126048 commit 106316f
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 57 deletions.
37 changes: 29 additions & 8 deletions proxy/tcp/sni_proxy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tcp

import (
"bufio"
"io"
"log"
"net"
Expand Down Expand Up @@ -41,20 +42,40 @@ func (p *SNIProxy) ServeTCP(in net.Conn) error {
p.Conn.Inc(1)
}

// capture client hello
data := make([]byte, 1024)
n, err := in.Read(data)
tlsReader := bufio.NewReader(in)
tlsHeaders, err := tlsReader.Peek(9)
if err != nil {
log.Print("[DEBUG] tcp+sni: TLS handshake failed (failed to peek data)")
if p.ConnFail != nil {
p.ConnFail.Inc(1)
}
return err
}
data = data[:n]

host, ok := readServerName(data)
bufferSize, err := clientHelloBufferSize(tlsHeaders)
if err != nil {
log.Printf("[DEBUG] tcp+sni: TLS handshake failed (%s)", err)
if p.ConnFail != nil {
p.ConnFail.Inc(1)
}
return err
}

data := make([]byte, bufferSize)
_, err = io.ReadFull(tlsReader, data)
if err != nil {
log.Printf("[DEBUG] tcp+sni: TLS handshake failed (%s)", err)
if p.ConnFail != nil {
p.ConnFail.Inc(1)
}
return err
}

// readServerName wants only the handshake message so ignore the first
// 5 bytes which is the TLS record header
host, ok := readServerName(data[5:])
if !ok {
log.Print("[DEBUG] tcp+sni: TLS handshake failed")
log.Print("[DEBUG] tcp+sni: TLS handshake failed (unable to parse client hello)")
if p.ConnFail != nil {
p.ConnFail.Inc(1)
}
Expand Down Expand Up @@ -92,8 +113,8 @@ func (p *SNIProxy) ServeTCP(in net.Conn) error {
}
defer out.Close()

// copy client hello
n, err = out.Write(data)
// write the data already read from the connection
n, err := out.Write(data)
if err != nil {
log.Print("[WARN] tcp+sni: copy client hello failed. ", err)
if p.ConnFail != nil {
Expand Down
90 changes: 41 additions & 49 deletions proxy/tcp/tls_clienthello.go
Original file line number Diff line number Diff line change
@@ -1,73 +1,65 @@
package tcp

// record types
const (
handshakeRecord = 0x16
clientHelloType = 0x01
)

// readServerName returns the server name from a TLS ClientHello message which
// has the server_name extension (SNI). ok is set to true if the ClientHello
// message was parsed successfully. If the server_name extension was not set
// and empty string is returned as serverName.
func readServerName(data []byte) (serverName string, ok bool) {
if m, ok := readClientHello(data); ok {
return m.serverName, true
}
return "", false
}

// readClientHello
func readClientHello(data []byte) (m *clientHelloMsg, ok bool) {
if len(data) < 9 {
// println("buf too short")
return nil, false
}
import "errors"

// Determines the required size of a buffer large enough to hold
// a client hello message including the tls record header and the
// handshake message header.
// The function requires at least the first 9 bytes of the tls conversation
// in "data".
// An error is returned if the data does not follow the
// specification (https://tools.ietf.org/html/rfc5246) or if the client hello
// is fragmented over multiple records.
func clientHelloBufferSize(data []byte) (int, error) {
// TLS record header
// -----------------
// byte 0: rec type (should be 0x16 == Handshake)
// byte 1-2: version (should be 0x3000 < v < 0x3003)
// byte 3-4: rec len
recType := data[0]
if recType != handshakeRecord {
// println("no handshake ")
return nil, false
if len(data) < 9 {
return 0, errors.New("At least 9 bytes required to determine client hello length")
}

recLen := int(data[3])<<8 | int(data[4])
if recLen == 0 || recLen > len(data)-5 {
// println("rec too short")
return nil, false
if data[0] != 0x16 {
return 0, errors.New("Not a TLS handshake")
}

recordLength := int(data[3])<<8 | int(data[4])
if recordLength <= 0 || recordLength > 16384 {
return 0, errors.New("Invalid TLS record length")
}

// Handshake record header
// -----------------------
// byte 5: hs msg type (should be 0x01 == client_hello)
// byte 6-8: hs msg len
hsType := data[5]
if hsType != clientHelloType {
// println("no client_hello")
return nil, false
if data[5] != 0x01 {
return 0, errors.New("Not a client hello")
}

hsLen := int(data[6])<<16 | int(data[7])<<8 | int(data[8])
if hsLen == 0 || hsLen > len(data)-9 {
// println("handshake rec too short")
return nil, false
handshakeLength := int(data[6])<<16 | int(data[7])<<8 | int(data[8])
if handshakeLength <= 0 || handshakeLength > recordLength-4 {
return 0, errors.New("Invalid client hello length (fragmentation not implemented)")
}

// byte 9- : client hello msg
//
// m.unmarshal parses the entire handshake message and
// not just the client hello. Therefore, we need to pass
// data from byte 5 instead of byte 9. (see comment below)
m = new(clientHelloMsg)
if !m.unmarshal(data[5:]) {
// println("client_hello unmarshal failed")
return nil, false
return handshakeLength + 9, nil //9 for the header bytes
}

// readServerName returns the server name from a TLS ClientHello message which
// has the server_name extension (SNI). ok is set to true if the ClientHello
// message was parsed successfully. If the server_name extension was not set
// an empty string is returned as serverName.
// clientHelloHandshakeMsg must contain the full client hello handshake
// message including the 4 byte header.
// See: https://www.ietf.org/rfc/rfc5246.txt
func readServerName(clientHelloHandshakeMsg []byte) (serverName string, ok bool) {
m := new(clientHelloMsg)
if !m.unmarshal(clientHelloHandshakeMsg) {
//println("client_hello unmarshal failed")
return "", false
}
return m, true

return m.serverName, true
}

// The code below is a verbatim copy from go1.7/src/crypto/tls/handshake_messages.go
Expand Down
162 changes: 162 additions & 0 deletions proxy/tcp/tls_clienthello_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package tcp

import (
"encoding/hex"
"testing"
)

func TestClientHelloBufferSize(t *testing.T) {
tests := []struct {
name string
data []byte
size int
fail bool
}{
{
name: "valid data",
// Largest possible client hello message
// |- 16384 -| |----- 16380 ----|
data: []byte{0x16, 0x03, 0x01, 0x40, 0x00, 0x01, 0x00, 0x3f, 0xfc},
size: 16384 + 5, // max record length + record header
fail: false,
},
{
name: "not enough data",
data: []byte{0x16, 0x03, 0x01, 0x40, 0x00, 0x01, 0x00, 0x3f},
size: 0,
fail: true,
},
{
name: "not a TLS record",
data: []byte{0x15, 0x03, 0x01, 0x01, 0xF4, 0x01, 0x00, 0x01, 0xeb},
size: 0,
fail: true,
},

{
name: "TLS record too large",
// | max + 1 |
data: []byte{0x16, 0x03, 0x01, 0x40, 0x01, 0x01, 0x00, 0x3f, 0xfc},
size: 0,
fail: true,
},

{
name: "TLS record length zero",
// |----------|
data: []byte{0x16, 0x03, 0x01, 0x00, 0x00, 0x01, 0x00, 0x3f, 0xfc},
size: 0,
fail: true,
},

{
name: "Not a client hello",
// |----|
data: []byte{0x16, 0x03, 0x01, 0x40, 0x00, 0x02, 0x00, 0x3f, 0xfc},
size: 0,
fail: true,
},

{
name: "Invalid handshake message record length",
// |----- 0 --------|
data: []byte{0x16, 0x03, 0x01, 0x40, 0x00, 0x01, 0x00, 0x00, 0x00},
size: 0,
fail: true,
},

{
name: "Fragmentation (handshake message larger than record)",
// |- 500 ---| |----- 497 ------|
data: []byte{0x16, 0x03, 0x01, 0x01, 0xF4, 0x01, 0x00, 0x01, 0xf1},
size: 0,
fail: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := clientHelloBufferSize(tt.data)

if tt.fail && err == nil {
t.Fatal("expected error, got nil")
} else if !tt.fail && err != nil {
t.Fatalf("expected error to be nil, got %s", err)
}

if want := tt.size; got != want {
t.Fatalf("want size %d, got %d", want, got)
}
})
}
}

func TestReadServerName(t *testing.T) {
tests := []struct {
name string
servername string
ok bool
data string //Hex string, decoded by test
}{
{
// Client hello from:
// openssl s_client -connect google.com:443 -servername google.com
name: "valid client hello with server name",
servername: "google.com",
ok: true,
data: "0100014803032657cacce41598fa82e5b75061050bc31c5affdba106b8e7431852" +
"24af0fa1aa000098cc14cc13cc15c030c02cc028c024c014c00a00a3009f00" +
"6b006a00390038ff8500c400c3008800870081c032c02ec02ac026c00fc005" +
"009d003d003500c00084c02fc02bc027c023c013c00900a2009e0067004000" +
"33003200be00bd00450044c031c02dc029c025c00ec004009c003c002f00ba" +
"0041c011c007c00cc00200050004c012c00800160013c00dc003000a001500" +
"12000900ff010000870000000f000d00000a676f6f676c652e636f6d000b00" +
"0403000102000a003a0038000e000d0019001c000b000c001b00180009000a" +
"001a0016001700080006000700140015000400050012001300010002000300" +
"0f0010001100230000000d00260024060106020603efef0501050205030401" +
"04020403eeeeeded030103020303020102020203",
},
{
// Client hello from:
// openssl s_client -connect google.com:443
name: "valid client hello but no server name extension",
servername: "",
ok: true,
data: "0100013503036dfb09de7b16503dd1bb304dcbe54079913b65abf53de997f73b26c99e" +
"67ba28000098cc14cc13cc15c030c02cc028c024c014c00a00a3009f006b006a00" +
"390038ff8500c400c3008800870081c032c02ec02ac026c00fc005009d003d0035" +
"00c00084c02fc02bc027c023c013c00900a2009e006700400033003200be00bd00" +
"450044c031c02dc029c025c00ec004009c003c002f00ba0041c011c007c00cc002" +
"00050004c012c00800160013c00dc003000a00150012000900ff01000074000b00" +
"0403000102000a003a0038000e000d0019001c000b000c001b00180009000a001a" +
"00160017000800060007001400150004000500120013000100020003000f001000" +
"1100230000000d00260024060106020603efef050105020503040104020403eeee" +
"eded030103020303020102020203",
},
{
name: "invalid client hello",
servername: "",
ok: false,
data: "0100014c5768656e2070656f706c652073617920746f206d653a20776f756c6420796f" +
"75207261746865722062652074686f75676874206f6620617320612066756e6e79" +
"206d616e206f72206120677265617420626f73733f204d7920616e737765722773" +
"20616c77617973207468652073616d652c20746f206d652c207468657927726520" +
"6e6f74206d757475616c6c79206578636c75736976652e2d204461766964204272" +
"656e74",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clientHelloMsg, _ := hex.DecodeString(tt.data)
servername, ok := readServerName(clientHelloMsg)
if got, want := servername, tt.servername; got != want {
t.Fatalf("%s: got servername \"%s\" want \"%s\"", tt.name, got, want)
}

if got, want := ok, tt.ok; got != want {
t.Fatalf("%s: got ok %t want %t", tt.name, got, want)
}
})
}
}

0 comments on commit 106316f

Please sign in to comment.