Skip to content

Commit

Permalink
transport: Add test for deny incoming peer connection from wrong ip SAN
Browse files Browse the repository at this point in the history
  • Loading branch information
wenjiaswe committed May 21, 2019
1 parent 06cec40 commit 59763e0
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 4 deletions.
4 changes: 2 additions & 2 deletions pkg/transport/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (info TLSInfo) Empty() bool {
return info.CertFile == "" && info.KeyFile == ""
}

func SelfCert(dirpath string, hosts []string) (info TLSInfo, err error) {
func SelfCert(dirpath string, hosts []string, additionalUsages ...x509.ExtKeyUsage) (info TLSInfo, err error) {
if err = os.MkdirAll(dirpath, 0700); err != nil {
return
}
Expand Down Expand Up @@ -118,7 +118,7 @@ func SelfCert(dirpath string, hosts []string) (info TLSInfo, err error) {
NotAfter: time.Now().Add(365 * (24 * time.Hour)),

KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
ExtKeyUsage: append([]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, additionalUsages...),
BasicConstraintsValid: true,
}

Expand Down
104 changes: 102 additions & 2 deletions pkg/transport/listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,20 @@ import (
"os"
"testing"
"time"
"crypto/x509"
"net"
)

func createSelfCert() (*TLSInfo, func(), error) {
func createSelfCert(hosts ...string) (*TLSInfo, func(), error) {
return createSelfCertEx("127.0.0.1")
}

func createSelfCertEx(host string, additionalUsages ...x509.ExtKeyUsage) (*TLSInfo, func(), error) {
d, terr := ioutil.TempDir("", "etcd-test-tls-")
if terr != nil {
return nil, nil, terr
}
info, err := SelfCert(d, []string{"127.0.0.1"})
info, err := SelfCert(d, []string{host + ":0"}, additionalUsages...)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -74,6 +80,100 @@ func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
}
}

// TestNewListenerTLSInfoClientVerify tests that if client IP address mismatches
// with specified address in its certificate the connection is rejected
func TestNewListenerTLSInfoClientVerify(t *testing.T) {
tests := []struct {
goodClientHost bool
acceptExpected bool
}{
{true, true},
{false, false},
}
for _, test := range tests {
testNewListenerTLSInfoClientCheck(t, test.goodClientHost, test.acceptExpected)
}
}

func testNewListenerTLSInfoClientCheck(t *testing.T, goodClientHost, acceptExpected bool) {
tlsInfo, del, err := createSelfCert()
if err != nil {
t.Fatalf("unable to create cert: %v", err)
}
defer del()

host := "127.0.0.222"
if goodClientHost {
host = "127.0.0.1"
}
clientTLSInfo, del2, err := createSelfCertEx(host, x509.ExtKeyUsageClientAuth)
if err != nil {
t.Fatalf("unable to create cert: %v", err)
}
defer del2()

tlsInfo.CAFile = clientTLSInfo.CertFile

rootCAs := x509.NewCertPool()
loaded, err := ioutil.ReadFile(tlsInfo.CertFile)
if err != nil {
t.Fatalf("unexpected missing certfile: %v", err)
}
rootCAs.AppendCertsFromPEM(loaded)

clientCert, err := tls.LoadX509KeyPair(clientTLSInfo.CertFile, clientTLSInfo.KeyFile)
if err != nil {
t.Fatalf("unable to create peer cert: %v", err)
}

tlsConfig := &tls.Config{}
tlsConfig.InsecureSkipVerify = false
tlsConfig.Certificates = []tls.Certificate{clientCert}
tlsConfig.RootCAs = rootCAs

ln, err := NewListener("127.0.0.1:0", "https", tlsInfo)
if err != nil {
t.Fatalf("unexpected NewListener error: %v", err)
}
defer ln.Close()

tr := &http.Transport{TLSClientConfig: tlsConfig}
cli := &http.Client{Transport: tr}
chClientErr := make(chan error)
go func() {
_, err := cli.Get("https://" + ln.Addr().String())
chClientErr <- err
}()

chAcceptErr := make(chan error)
chAcceptConn := make(chan net.Conn)
go func() {
conn, err := ln.Accept()
if err != nil {
chAcceptErr <- err
} else {
chAcceptConn <- conn
}
}()

select {
case <-chClientErr:
if acceptExpected {
t.Errorf("accepted for good client address: goodClientHost=%t", goodClientHost)
}
case acceptErr := <-chAcceptErr:
t.Fatalf("unexpected Accept error: %v", acceptErr)
case conn := <-chAcceptConn:
defer conn.Close()
if _, ok := conn.(*tls.Conn); !ok {
t.Errorf("failed to accept *tls.Conn")
}
if !acceptExpected {
t.Errorf("accepted for bad client address: goodClientHost=%t", goodClientHost)
}
}
}

func TestNewListenerTLSEmptyInfo(t *testing.T) {
_, err := NewListener("127.0.0.1:0", "https", nil)
if err == nil {
Expand Down
1 change: 1 addition & 0 deletions pkg/transport/listener_tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ func (l *tlsListener) acceptLoop() {
}

st := tlsConn.ConnectionState()
fmt.Printf("st.PeerCertificates len: %d\n", len(st.PeerCertificates) )
if len(st.PeerCertificates) > 0 {
cert := st.PeerCertificates[0]
addr := tlsConn.RemoteAddr().String()
Expand Down

0 comments on commit 59763e0

Please sign in to comment.