From 65e743c5fc59df0916f5c58e01e757fd5be5a847 Mon Sep 17 00:00:00 2001 From: Frank Schroeder Date: Tue, 21 Feb 2017 19:51:09 +0100 Subject: [PATCH] Issue #179: TCP Proxy Support * Add generic TCP proxy support. * Add support for ReadTimeout and WriteTimeout for the TCP and the TCP+SNI proxy. * Add integration tests for the TCP and TCP+SNI proxy. * Update the demo server to provide a TCP server. * Add a tcptest package for generic TCP server testing. Fixes #178, #179 --- config/load.go | 9 +- config/load_test.go | 15 +++- demo/server/server.go | 164 +++++++++++++++++++++++----------- fabio.properties | 10 +-- main.go | 34 ++++--- proxy/http_proxy.go | 7 -- proxy/internal/testcert.go | 37 ++++++++ proxy/listen.go | 2 +- proxy/serve.go | 10 +++ proxy/tcp/server.go | 44 +++++++-- proxy/tcp/sni_proxy.go | 27 +++--- proxy/tcp/tcp_proxy.go | 52 +++++++++++ proxy/tcp/tcptest/dialer.go | 66 ++++++++++++++ proxy/tcp/tcptest/server.go | 106 ++++++++++++++++++++++ proxy/tcp_integration_test.go | 109 ++++++++++++++++++++++ registry/consul/parse.go | 6 ++ 16 files changed, 591 insertions(+), 107 deletions(-) create mode 100644 proxy/internal/testcert.go create mode 100644 proxy/tcp/tcp_proxy.go create mode 100644 proxy/tcp/tcptest/dialer.go create mode 100644 proxy/tcp/tcptest/server.go create mode 100644 proxy/tcp_integration_test.go diff --git a/config/load.go b/config/load.go index c6dd3cac9..e553a6cd8 100644 --- a/config/load.go +++ b/config/load.go @@ -267,7 +267,10 @@ func parseListen(cfg string, cs map[string]CertSource, readTimeout, writeTimeout switch k { case "proto": l.Proto = v - if l.Proto != "http" && l.Proto != "https" && l.Proto != "tcp+sni" { + switch l.Proto { + case "tcp", "tcp+sni", "http", "https": + // ok + default: return Listen{}, fmt.Errorf("unknown protocol %q", v) } case "rt": // read timeout @@ -300,8 +303,8 @@ func parseListen(cfg string, cs map[string]CertSource, readTimeout, writeTimeout if l.Proto == "" { l.Proto = "http" } - if csName != "" && l.Proto != "https" { - return Listen{}, fmt.Errorf("cert source requires proto 'https'") + if csName != "" && l.Proto != "https" && l.Proto != "tcp" { + return Listen{}, fmt.Errorf("cert source requires proto 'https' or 'tcp'") } if csName == "" && l.Proto == "https" { return Listen{}, fmt.Errorf("proto 'https' requires cert source") diff --git a/config/load_test.go b/config/load_test.go index cf2ab56c7..f9d2bf898 100644 --- a/config/load_test.go +++ b/config/load_test.go @@ -62,6 +62,13 @@ func TestLoad(t *testing.T) { return cfg }, }, + { + args: []string{"-proxy.addr", ":5555;proto=tcp"}, + cfg: func(cfg *Config) *Config { + cfg.Listen = []Listen{{Addr: ":5555", Proto: "tcp"}} + return cfg + }, + }, { args: []string{"-proxy.addr", ":5555;proto=tcp+sni"}, cfg: func(cfg *Config) *Config { @@ -686,16 +693,16 @@ func TestLoad(t *testing.T) { err: errors.New("proto 'https' requires cert source"), }, { - desc: "-proxy.addr with cert source and proto 'http' requires proto 'https'", + desc: "-proxy.addr with cert source and proto 'http' requires proto 'https' or 'tcp'", args: []string{"-proxy.addr", ":5555;cs=name;proto=http", "-proxy.cs", "cs=name;type=path;cert=value"}, cfg: func(cfg *Config) *Config { return nil }, - err: errors.New("cert source requires proto 'https'"), + err: errors.New("cert source requires proto 'https' or 'tcp'"), }, { - desc: "-proxy.addr with cert source and proto 'tcp+sni' requires proto 'https'", + desc: "-proxy.addr with cert source and proto 'tcp+sni' requires proto 'https' or 'tcp'", args: []string{"-proxy.addr", ":5555;cs=name;proto=tcp+sni", "-proxy.cs", "cs=name;type=path;cert=value"}, cfg: func(cfg *Config) *Config { return nil }, - err: errors.New("cert source requires proto 'https'"), + err: errors.New("cert source requires proto 'https' or 'tcp'"), }, { args: []string{"-cfg"}, diff --git a/demo/server/server.go b/demo/server/server.go index 8122a9b35..b2610a692 100644 --- a/demo/server/server.go +++ b/demo/server/server.go @@ -24,9 +24,13 @@ // # websocket server // ./server -addr 127.0.0.1:6000 -name ws-a -prefix /echo1,/echo2 -proto ws // +// # tcp server +// ./server -addr 127.0.0.1:7000 -name tcp-a -proto tcp +// package main import ( + "bufio" "flag" "fmt" "io" @@ -39,10 +43,49 @@ import ( "strconv" "strings" + "github.com/eBay/fabio/proxy/tcp" "github.com/hashicorp/consul/api" "golang.org/x/net/websocket" ) +type Server interface { + // embedded server methods + ListenAndServe() error + ListenAndServeTLS(certFile, keyFile string) error + + // consul register helpers + Tags() []string + Check() *api.AgentServiceCheck +} + +type HTTPServer struct { + *http.Server + tags []string + check *api.AgentServiceCheck +} + +func (s *HTTPServer) Check() *api.AgentServiceCheck { + return s.check +} + +func (s *HTTPServer) Tags() []string { + return s.tags +} + +type TCPServer struct { + *tcp.Server + tags []string + check *api.AgentServiceCheck +} + +func (s *TCPServer) Check() *api.AgentServiceCheck { + return s.check +} + +func (s *TCPServer) Tags() []string { + return s.tags +} + func main() { var addr, consul, name, prefix, proto, token string var certFile, keyFile string @@ -50,8 +93,8 @@ func main() { flag.StringVar(&addr, "addr", "127.0.0.1:5000", "host:port of the service") flag.StringVar(&consul, "consul", "127.0.0.1:8500", "host:port of the consul agent") flag.StringVar(&name, "name", filepath.Base(os.Args[0]), "name of the service") - flag.StringVar(&prefix, "prefix", "", "comma-sep list of host/path prefixes to register") - flag.StringVar(&proto, "proto", "http", "protocol for endpoints: http or ws") + flag.StringVar(&prefix, "prefix", "", "comma-sep list of 'host/path' or ':port' prefixes to register") + flag.StringVar(&proto, "proto", "http", "protocol for endpoints: http, ws or tcp") flag.StringVar(&token, "token", "", "consul ACL token") flag.StringVar(&certFile, "cert", "", "path to cert file") flag.StringVar(&keyFile, "key", "", "path to key file") @@ -63,49 +106,62 @@ func main() { os.Exit(1) } - // register prefixes - prefixes := strings.Split(prefix, ",") - for _, p := range prefixes { - switch proto { - case "http": - http.HandleFunc(p, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(status) - fmt.Fprintf(w, "Serving %s from %s on %s\n", r.RequestURI, name, addr) - }) - case "ws": - http.Handle(p, websocket.Handler(EchoServer)) - default: - log.Fatal("Invalid protocol ", proto) + var srv Server + switch proto { + case "http", "ws": + mux := http.NewServeMux() + mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "OK") + }) + + var tags []string + for _, p := range strings.Split(prefix, ",") { + tags = append(tags, "urlprefix-"+p) + switch proto { + case "http": + mux.HandleFunc(p, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(status) + fmt.Fprintf(w, "Serving %s from %s on %s\n", r.RequestURI, name, addr) + }) + case "ws": + mux.Handle(p, websocket.Handler(WSEchoServer)) + } } - } - // register consul health check endpoint - http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "OK") - }) + var check *api.AgentServiceCheck + if certFile != "" { + check = &api.AgentServiceCheck{TCP: addr, Interval: "2s", Timeout: "1s"} + } else { + check = &api.AgentServiceCheck{HTTP: "http://" + addr + "/health", Interval: "1s", Timeout: "1s"} + } - // start http server - go func() { - log.Printf("Listening on %s serving %s", addr, prefix) + srv = &HTTPServer{&http.Server{Addr: addr, Handler: mux}, tags, check} + + case "tcp": + var tags []string + for _, p := range strings.Split(prefix, ",") { + tags = append(tags, "urlprefix-"+p+" proto=tcp") + } + check := &api.AgentServiceCheck{TCP: addr, Interval: "2s", Timeout: "1s"} + srv = &TCPServer{&tcp.Server{Addr: addr, Handler: tcp.HandlerFunc(TCPEchoHandler)}, tags, check} + + default: + log.Fatal("Invalid protocol ", proto) + } + // start server + go func() { var err error if certFile != "" { - err = http.ListenAndServeTLS(addr, certFile, keyFile, nil) + err = srv.ListenAndServeTLS(certFile, keyFile) } else { - err = http.ListenAndServe(addr, nil) + err = srv.ListenAndServe() } if err != nil { log.Fatal(err) } }() - // build urlprefix-host/path tag list - // e.g. urlprefix-/foo, urlprefix-/bar, ... - var tags []string - for _, p := range prefixes { - tags = append(tags, "urlprefix-"+p) - } - // get host and port as string/int host, portstr, err := net.SplitHostPort(addr) if err != nil { @@ -116,21 +172,6 @@ func main() { log.Fatal(err) } - var check *api.AgentServiceCheck - if certFile != "" { - check = &api.AgentServiceCheck{ - TCP: addr, - Interval: "2s", - Timeout: "1s", - } - } else { - check = &api.AgentServiceCheck{ - HTTP: "http://" + addr + "/health", - Interval: "1s", - Timeout: "1s", - } - } - // register service with health check serviceID := name + "-" + addr service := &api.AgentServiceRegistration{ @@ -138,8 +179,8 @@ func main() { Name: name, Port: port, Address: host, - Tags: tags, - Check: check, + Tags: srv.Tags(), + Check: srv.Check(), } config := &api.Config{Address: consul, Scheme: "http", Token: token} @@ -151,7 +192,7 @@ func main() { if err := client.Agent().ServiceRegister(service); err != nil { log.Fatal(err) } - log.Printf("Registered service %q in consul with tags %q", name, strings.Join(tags, ",")) + log.Printf("Registered %s service %q in consul with tags %q", proto, name, strings.Join(srv.Tags(), ",")) // run until we get a signal quit := make(chan os.Signal, 1) @@ -165,7 +206,7 @@ func main() { log.Printf("Deregistered service %q in consul", name) } -func EchoServer(ws *websocket.Conn) { +func WSEchoServer(ws *websocket.Conn) { addr := ws.LocalAddr().String() pfx := []byte("[" + addr + "] ") @@ -188,3 +229,24 @@ func EchoServer(ws *websocket.Conn) { } log.Printf("ws disconnect on %s", addr) } + +func TCPEchoHandler(c net.Conn) error { + defer c.Close() + + addr := c.LocalAddr().String() + _, err := fmt.Fprintf(c, "[%s] Welcome\n", addr) + if err != nil { + return err + } + + for { + line, _, err := bufio.NewReader(c).ReadLine() + if err != nil { + return err + } + _, err = fmt.Fprintf(c, "[%s] %s\n", addr, string(line)) + if err != nil { + return err + } + } +} diff --git a/fabio.properties b/fabio.properties index 80c94f4f6..65ae4d2ed 100644 --- a/fabio.properties +++ b/fabio.properties @@ -177,7 +177,8 @@ # # * http for HTTP based protocols # * https for HTTPS based protocols -# * tcp+sni for an SNI aware TCP proxy (EXPERIMENTAL) +# * tcp for a raw TCP proxy with or witout TLS support +# * tcp+sni for an SNI aware TCP proxy # # If no 'proto' option is specified then the protocol # is either 'http' or 'https' depending on whether a @@ -189,10 +190,6 @@ # extension and then forwards the encrypted traffic # to the destination without decrypting the traffic. # -# The TCP+SNI proxy is currently marked as EXPERIMENTAL -# since it needs more real-world testing and an integration -# test. -# # General options: # # rt: Sets the read timeout as a duration value (e.g. '3s') @@ -223,6 +220,9 @@ # # HTTPS listener on port 443 with certificate source # proxy.addr = :443;cs=some-name # +# # TCP listener on port 1234 with port routing +# proxy.addr = :1234;proto=tcp +# # # TCP listener on port 443 with SNI routing # proxy.addr = :443;proto=tcp+sni # diff --git a/main.go b/main.go index 5549a2e7c..7c8f42854 100644 --- a/main.go +++ b/main.go @@ -36,7 +36,7 @@ import ( // It is also set by the linker when fabio // is built via the Makefile or the build/docker.sh // script to ensure the correct version nubmer -var version = "1.3.8" +var version = "1.4beta1" var shuttingDown int32 @@ -81,6 +81,7 @@ func main() { func newHTTPProxy(cfg *config.Config) http.Handler { pick := route.Picker[cfg.Proxy.Strategy] match := route.Matcher[cfg.Proxy.Matcher] + notFound := metrics.DefaultRegistry.GetCounter("notfound") log.Printf("[INFO] Using routing strategy %q", cfg.Proxy.Strategy) log.Printf("[INFO] Using route matching %q", cfg.Proxy.Matcher) @@ -97,26 +98,26 @@ func newHTTPProxy(cfg *config.Config) http.Handler { Lookup: func(r *http.Request) *route.Target { t := route.GetTable().Lookup(r, r.Header.Get("trace"), pick, match) if t == nil { + notFound.Inc(1) log.Print("[WARN] No route for ", r.Host, r.URL) } return t }, Requests: metrics.DefaultRegistry.GetTimer("requests"), - Noroute: metrics.DefaultRegistry.GetCounter("notfound"), } } -func newTCPSNIProxy(cfg *config.Config) *tcp.SNIProxy { +func lookupHostFn(cfg *config.Config) func(string) string { pick := route.Picker[cfg.Proxy.Strategy] - return &tcp.SNIProxy{ - Config: cfg.Proxy, - Lookup: func(host string) *route.Target { - t := route.GetTable().LookupHost(host, pick) - if t == nil { - log.Print("[WARN] No route for ", host) - } - return t - }, + notFound := metrics.DefaultRegistry.GetCounter("notfound") + return func(host string) string { + t := route.GetTable().LookupHost(host, pick) + if t == nil { + notFound.Inc(1) + log.Print("[WARN] No route for ", host) + return "" + } + return t.URL.Host } } @@ -140,9 +141,14 @@ func startServers(cfg *config.Config) { for _, l := range cfg.Listen { switch l.Proto { case "http", "https": - go proxy.ListenAndServeHTTP(l, newHTTPProxy(cfg)) + h := newHTTPProxy(cfg) + go proxy.ListenAndServeHTTP(l, h) + case "tcp": + h := &tcp.Proxy{cfg.Proxy.DialTimeout, lookupHostFn(cfg)} + go proxy.ListenAndServeTCP(l, h) case "tcp+sni": - go proxy.ListenAndServeTCP(l, newTCPSNIProxy(cfg)) + h := &tcp.SNIProxy{cfg.Proxy.DialTimeout, lookupHostFn(cfg)} + go proxy.ListenAndServeTCP(l, h) default: exit.Fatal("[FATAL] Invalid protocol ", l.Proto) } diff --git a/proxy/http_proxy.go b/proxy/http_proxy.go index faad70b5c..d521080f5 100644 --- a/proxy/http_proxy.go +++ b/proxy/http_proxy.go @@ -26,10 +26,6 @@ type HTTPProxy struct { // Requests is a timer metric which is updated for every request. Requests metrics.Timer - - // Noroute is a counter metric which is updated for every request - // where Lookup() returns nil. - Noroute metrics.Counter } func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -39,9 +35,6 @@ func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { t := p.Lookup(r) if t == nil { - if p.Noroute != nil { - p.Noroute.Inc(1) - } w.WriteHeader(p.Config.NoRouteStatus) return } diff --git a/proxy/internal/testcert.go b/proxy/internal/testcert.go new file mode 100644 index 000000000..bb6a9ca73 --- /dev/null +++ b/proxy/internal/testcert.go @@ -0,0 +1,37 @@ +package internal + +// LocalhostCert is a PEM-encoded TLS cert with SAN IPs +// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT. +// generated from src/crypto/tls: +// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +var LocalhostCert = []byte(`-----BEGIN CERTIFICATE----- +MIICEzCCAXygAwIBAgIQS3cofn+2H4NxFntgaMRAPTANBgkqhkiG9w0BAQsFADAS +MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw +MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB +iQKBgQDp+sQVBNYwZ4YSskddAtTYq2NPdWYawNw9YQDBU9ft3fIm1r9UoyL/57bo +gCgFAkglXo06sAfuk+W6OXRPplEwxCU/mAiAjMLKES1V3oZnI42sTeiskdvb8j6E +47EpbWSA2OU4Nqulbh6vkGrzYzUdlmwwz+rGvfmHp1EOjMVzvQIDAQABo2gwZjAO +BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw +AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA +AAAAATANBgkqhkiG9w0BAQsFAAOBgQBChdgkaHaw83GFx8aDWoE3K4+h9YqXuvEP +b2OWAYlzY/U99BA9P0lE4vGpaIAeCFxalJ2AK3yHjt+eezy3sw0bMeG8ZNYcOyIV +exS95UdAKFt93a5zIWrkYQvhuzln1IOxPJQZ4rkq4nikLj2WuyGR7QnuVBdgPqP7 +RN4BPb5Sog== +-----END CERTIFICATE-----`) + +// LocalhostKey is the private key for LocalhostCert. +var LocalhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQDp+sQVBNYwZ4YSskddAtTYq2NPdWYawNw9YQDBU9ft3fIm1r9U +oyL/57bogCgFAkglXo06sAfuk+W6OXRPplEwxCU/mAiAjMLKES1V3oZnI42sTeis +kdvb8j6E47EpbWSA2OU4Nqulbh6vkGrzYzUdlmwwz+rGvfmHp1EOjMVzvQIDAQAB +AoGAGVoXduOQRaxh5ZK1kslkwJlJaGmjB5EQDAJ/r3LjOZ3LyBOKpaQLfcjgk66X +J3vIz2vAR7SdF2elA5mIFb1CnJ4HW4cWHzgFQdUnUtoUNuMPy/9QREFfeag9GMPx +dZNiypiKqHDSY5ovUL92gtv5W0/w00lYpFiBaYLl+WHvQ6ECQQDvZpULZCEmZHwL +hZun4ObzLwFNZ9sNPgwJybnxVYaolXACeh4Ewur0kZlY9DJMqo7Rz82JWuFarkgU +GQK/L231AkEA+jP0+q7jfI8NJqwpWFDjwKiI7fadClcdUgXvW2c5wc2pEe4KiAqs +ZOWPGsH7SxigGRLzw01SCoInX5yw689JqQJARIOTPENXyWkQpyuBtLYE4qwdL039 +vvh28YYuFQdpFm5ONCdG2A4AuCXDQVYB3zcg0KMsK5c6z3z5W+cchiLI0QJBAIDS +ZYz4pNoKEVxbAgKdy1XzsGTNN/gN+GO1+JJYKK23RRidNkDrNe3RIAhH3inBKRUf +4/AnjFkqwDkDRTh0htkCQQDfrRZr+gazwzDTSp23+l6MEbqBbc+TTC3c40zpNj4a +egxjd5+SkMj6zXEJxAOgo+LmQDGWsu1YQ+XXL87VPwIP +-----END RSA PRIVATE KEY-----`) diff --git a/proxy/listen.go b/proxy/listen.go index 42cce6188..0696ffbc8 100644 --- a/proxy/listen.go +++ b/proxy/listen.go @@ -6,7 +6,7 @@ import ( "net" "time" - "github.com/armon/go-proxyproto" + proxyproto "github.com/armon/go-proxyproto" "github.com/eBay/fabio/cert" "github.com/eBay/fabio/config" ) diff --git a/proxy/serve.go b/proxy/serve.go index aa388720e..0d94fb9d6 100644 --- a/proxy/serve.go +++ b/proxy/serve.go @@ -24,10 +24,20 @@ var ( servers []Server ) +func Close() { + mu.Lock() + for _, srv := range servers { + srv.Close() + } + servers = []Server{} + mu.Unlock() +} + func Shutdown(timeout time.Duration) { mu.Lock() srvs := make([]Server, len(servers)) copy(srvs, servers) + servers = []Server{} mu.Unlock() var wg sync.WaitGroup diff --git a/proxy/tcp/server.go b/proxy/tcp/server.go index 72a9085ed..69d8c6e95 100644 --- a/proxy/tcp/server.go +++ b/proxy/tcp/server.go @@ -2,15 +2,26 @@ package tcp import ( "context" + "crypto/tls" "net" "sync" "time" ) +// Handler responds to a TCP request. +// +// ServeTCP should write responses to the in connection and close +// it on return. type Handler interface { ServeTCP(in net.Conn) error } +type HandlerFunc func(in net.Conn) error + +func (f HandlerFunc) ServeTCP(in net.Conn) error { + return f(in) +} + // Server implements a generic TCP server. type Server struct { Addr string @@ -20,7 +31,30 @@ type Server struct { mu sync.Mutex listeners []net.Listener - conns map[int64]net.Conn + conns map[net.Conn]bool +} + +func (s *Server) ListenAndServe() error { + l, err := net.Listen("tcp", s.Addr) + if err != nil { + return err + } + defer l.Close() + return s.Serve(l) +} + +func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return err + } + cfg := &tls.Config{Certificates: []tls.Certificate{cert}} + l, err := tls.Listen("tcp", s.Addr, cfg) + if err != nil { + return err + } + defer l.Close() + return s.Serve(l) } func (s *Server) Serve(l net.Listener) error { @@ -36,16 +70,15 @@ func (s *Server) Serve(l net.Listener) error { return err } c = &conn{ - id: time.Now().UnixNano(), c: c, ReadTimeout: s.ReadTimeout, WriteTimeout: s.WriteTimeout, } s.mu.Lock() if s.conns == nil { - s.conns = map[int64]net.Conn{} + s.conns = map[net.Conn]bool{} } - s.conns[c.(*conn).id] = c + s.conns[c] = true s.mu.Unlock() go s.Handler.ServeTCP(c) } @@ -63,7 +96,7 @@ func (s *Server) closeListeners() error { func (s *Server) closeConns() error { s.mu.Lock() - for _, c := range s.conns { + for c := range s.conns { c.Close() } s.conns = nil @@ -86,7 +119,6 @@ func (s *Server) Shutdown(ctx context.Context) error { // conn implements a connection which honors read and write timeouts. type conn struct { - id int64 c net.Conn ReadTimeout time.Duration WriteTimeout time.Duration diff --git a/proxy/tcp/sni_proxy.go b/proxy/tcp/sni_proxy.go index c49dab4a4..28c8b57d5 100644 --- a/proxy/tcp/sni_proxy.go +++ b/proxy/tcp/sni_proxy.go @@ -1,13 +1,10 @@ package tcp import ( - "fmt" "io" "log" "net" - - "github.com/eBay/fabio/config" - "github.com/eBay/fabio/route" + "time" ) // SNIProxy implements an SNI aware transparent TCP proxy which captures the @@ -16,12 +13,13 @@ import ( // transparently allowing to route a TLS connection based on the SNI header // without decrypting it. type SNIProxy struct { - // Config is the proxy configuration as provided during startup. - Config config.Proxy + // DialTimeout sets the timeout for establishing the outbound + // connection. + DialTimeout time.Duration // Lookup returns a target host for the given server name. // The proxy will panic if this value is nil. - Lookup func(string) *route.Target + Lookup func(host string) string } func (p *SNIProxy) ServeTCP(in net.Conn) error { @@ -35,28 +33,25 @@ func (p *SNIProxy) ServeTCP(in net.Conn) error { } data = data[:n] - serverName, ok := readServerName(data) + host, ok := readServerName(data) if !ok { - fmt.Fprintln(in, "handshake failed") log.Print("[DEBUG] tcp+sni: TLS handshake failed") return nil } - if serverName == "" { - fmt.Fprintln(in, "server_name missing") + if host == "" { log.Print("[DEBUG] tcp+sni: server_name missing") return nil } - t := p.Lookup(serverName) - if t == nil { - log.Print("[WARN] tcp+sni: No route for ", serverName) + addr := p.Lookup(host) + if addr == "" { return nil } - out, err := net.DialTimeout("tcp", t.URL.Host, p.Config.DialTimeout) + out, err := net.DialTimeout("tcp", addr, p.DialTimeout) if err != nil { - log.Print("[WARN] tcp+sni: cannot connect to upstream ", t.URL.Host) + log.Print("[WARN] tcp+sni: cannot connect to upstream ", addr) return err } defer out.Close() diff --git a/proxy/tcp/tcp_proxy.go b/proxy/tcp/tcp_proxy.go new file mode 100644 index 000000000..b97b25873 --- /dev/null +++ b/proxy/tcp/tcp_proxy.go @@ -0,0 +1,52 @@ +package tcp + +import ( + "io" + "log" + "net" + "time" +) + +// Proxy implements a generic TCP proxying handler. +type Proxy struct { + // DialTimeout sets the timeout for establishing the outbound + // connection. + DialTimeout time.Duration + + // Lookup returns a target host for the given server name. + // The proxy will panic if this value is nil. + Lookup func(host string) string +} + +func (p *Proxy) ServeTCP(in net.Conn) error { + defer in.Close() + + _, port, _ := net.SplitHostPort(in.LocalAddr().String()) + port = ":" + port + addr := p.Lookup(port) + if addr == "" { + return nil + } + + out, err := net.DialTimeout("tcp", addr, p.DialTimeout) + if err != nil { + log.Print("[WARN] tcp: cannot connect to upstream ", addr) + return err + } + defer out.Close() + + errc := make(chan error, 2) + cp := func(dst io.Writer, src io.Reader) { + _, err := io.Copy(dst, src) + errc <- err + } + + go cp(out, in) + go cp(in, out) + err = <-errc + if err != nil && err != io.EOF { + log.Print("[WARN]: tcp: ", err) + return err + } + return nil +} diff --git a/proxy/tcp/tcptest/dialer.go b/proxy/tcp/tcptest/dialer.go new file mode 100644 index 000000000..3ce8dacd1 --- /dev/null +++ b/proxy/tcp/tcptest/dialer.go @@ -0,0 +1,66 @@ +package tcptest + +import ( + "crypto/tls" + "net" + "time" +) + +type Dialer interface { + Dial(network, addr string) (net.Conn, error) +} + +func NewRetryDialer() *RetryDialer { + return &RetryDialer{} +} + +// RetryDialer retries the Dial function until it succeeds or +// the timeout has been reached. The default timeout is one +// second and the default sleep interval is 100ms. +type RetryDialer struct { + Dialer net.Dialer + Timeout time.Duration + Sleep time.Duration +} + +func (d *RetryDialer) Dial(network, addr string) (c net.Conn, err error) { + dial := func() (net.Conn, error) { return d.Dialer.Dial(network, addr) } + return retry(dial, d.Timeout, d.Sleep) +} + +func NewTLSRetryDialer(cfg *tls.Config) *TLSRetryDialer { + return &TLSRetryDialer{TLS: cfg} +} + +type TLSRetryDialer struct { + TLS *tls.Config + Dialer net.Dialer + Timeout time.Duration + Sleep time.Duration +} + +func (d *TLSRetryDialer) Dial(network, addr string) (c net.Conn, err error) { + dial := func() (net.Conn, error) { return tls.Dial(network, addr, d.TLS) } + return retry(dial, d.Timeout, d.Sleep) +} + +type dialer func() (net.Conn, error) + +func retry(dial dialer, timeout, sleep time.Duration) (c net.Conn, err error) { + if sleep == 0 { + sleep = 100 * time.Millisecond + } + if timeout == 0 { + timeout = time.Second + } + deadline := time.Now().Add(timeout) + + for { + c, err = dial() + if err != nil && time.Now().Before(deadline) { + time.Sleep(sleep) + continue + } + return + } +} diff --git a/proxy/tcp/tcptest/server.go b/proxy/tcp/tcptest/server.go new file mode 100644 index 000000000..c457b0187 --- /dev/null +++ b/proxy/tcp/tcptest/server.go @@ -0,0 +1,106 @@ +package tcptest + +import ( + "crypto/tls" + "fmt" + "net" + + "github.com/eBay/fabio/proxy/internal" + "github.com/eBay/fabio/proxy/tcp" +) + +// Server is a TCP test server that binds to a random port. +type Server struct { + // Addr is the address the server is listening on in the form ipaddr:port. + Addr string + Listener net.Listener + + // TLS is the optional TLS configuration, populated with a new config + // after TLS is started. If set on an unstarted server before StartTLS + // is called, existing fields are copied into the new config. + TLS *tls.Config + + // Config may be changed after calling NewUnstartedServer and + // before Start or StartTLS. + Config *tcp.Server + + // srv is the actual running server. + srv *tcp.Server +} + +func (s *Server) Start() { + if s.Addr != "" { + panic("Server already started") + } + + s.Addr = s.Listener.Addr().String() + s.srv = new(tcp.Server) + *s.srv = *s.Config + s.srv.Addr = s.Addr + go s.srv.Serve(s.Listener) +} + +func (s *Server) StartTLS() { + if s.Addr != "" { + panic("Server already started") + } + + s.Addr = s.Listener.Addr().String() + s.srv = new(tcp.Server) + *s.srv = *s.Config + s.srv.Addr = s.Addr + + cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey) + if err != nil { + panic(fmt.Sprintf("tcptest: NewTLSServer: %v", err)) + } + + existingConfig := s.TLS + if existingConfig != nil { + s.TLS = existingConfig.Clone() + } else { + s.TLS = new(tls.Config) + } + if len(s.TLS.Certificates) == 0 { + s.TLS.Certificates = []tls.Certificate{cert} + } + s.Listener = tls.NewListener(s.Listener, s.TLS) + go s.srv.Serve(s.Listener) +} + +func (s *Server) Close() error { + if s.Addr == "" { + panic("Server not started") + } + return s.srv.Close() +} + +func NewServer(h tcp.Handler) *Server { + srv := NewUnstartedServer(h) + srv.Start() + return srv +} + +func NewTLSServer(h tcp.Handler) *Server { + srv := NewUnstartedServer(h) + srv.StartTLS() + return srv +} + +func NewUnstartedServer(h tcp.Handler) *Server { + return &Server{ + Listener: newLocalListener(), + Config: &tcp.Server{Handler: h}, + } +} + +func newLocalListener() net.Listener { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + l, err = net.Listen("tcp6", "[::1]:0") + if err != nil { + panic("tcptest: Failed to listen on a port: " + err.Error()) + } + } + return l +} diff --git a/proxy/tcp_integration_test.go b/proxy/tcp_integration_test.go new file mode 100644 index 000000000..1d677adbf --- /dev/null +++ b/proxy/tcp_integration_test.go @@ -0,0 +1,109 @@ +package proxy + +import ( + "bufio" + "bytes" + "crypto/tls" + "crypto/x509" + "net" + "testing" + + "github.com/eBay/fabio/config" + "github.com/eBay/fabio/proxy/internal" + "github.com/eBay/fabio/proxy/tcp" + "github.com/eBay/fabio/proxy/tcp/tcptest" +) + +var echoHandler tcp.HandlerFunc = func(c net.Conn) error { + defer c.Close() + line, _, err := bufio.NewReader(c).ReadLine() + if err != nil { + return err + } + line = append(line, []byte(" echo")...) + _, err = c.Write(line) + return err +} + +func TestTCPProxy(t *testing.T) { + srv := tcptest.NewServer(echoHandler) + defer srv.Close() + + // start proxy + proxyAddr := "127.0.0.1:57778" + go func() { + h := &tcp.Proxy{ + Lookup: func(string) string { return srv.Addr }, + } + l := config.Listen{Addr: proxyAddr} + if err := ListenAndServeTCP(l, h); err != nil { + t.Log("ListenAndServeTCP: ", err) + } + }() + defer Close() + + // connect to proxy + out, err := tcptest.NewRetryDialer().Dial("tcp", proxyAddr) + if err != nil { + t.Fatalf("net.Dial: %#v", err) + } + defer out.Close() + + testRoundtrip(t, out) +} + +func TestTCPSNIProxy(t *testing.T) { + srv := tcptest.NewTLSServer(echoHandler) + defer srv.Close() + + // start tcp proxy + proxyAddr := "127.0.0.1:57778" + go func() { + h := &tcp.SNIProxy{ + Lookup: func(string) string { return srv.Addr }, + } + l := config.Listen{Addr: proxyAddr} + if err := ListenAndServeTCP(l, h); err != nil { + t.Log("ListenAndServeTCP: ", err) + } + }() + defer Close() + + rootCAs := x509.NewCertPool() + if ok := rootCAs.AppendCertsFromPEM(internal.LocalhostCert); !ok { + t.Fatal("could not parse cert") + } + cfg := &tls.Config{ + RootCAs: rootCAs, + ServerName: "example.com", + } + + // connect to proxy + out, err := tls.Dial("tcp", proxyAddr, cfg) + if err != nil { + t.Fatalf("net.Dial: %#v", err) + } + defer out.Close() + + testRoundtrip(t, out) +} + +func testRoundtrip(t *testing.T, c net.Conn) { + // send data to server + _, err := c.Write([]byte("foo\n")) + if err != nil { + t.Fatal("out.Write: ", err) + } + + // read response which should be + // src data + " echo" + line, _, err := bufio.NewReader(c).ReadLine() + if err != nil { + t.Fatal("readLine: ", err) + } + + // compare + if got, want := line, []byte("foo echo"); !bytes.Equal(got, want) { + t.Fatalf("got %q want %q", got, want) + } +} diff --git a/registry/consul/parse.go b/registry/consul/parse.go index 619524a36..e34caa755 100644 --- a/registry/consul/parse.go +++ b/registry/consul/parse.go @@ -32,6 +32,12 @@ func parseURLPrefixTag(s, prefix string, env map[string]string) (route, opts str } s = p[0] + // prefix is ":port" + if strings.HasPrefix(s, ":") { + return s, opts, true + } + + // prefix is "host/path" p = strings.SplitN(s, "/", 2) if len(p) == 1 { log.Printf("[WARN] consul: Invalid %s tag %q - You need to have a trailing slash!", prefix, s)