diff --git a/.travis.yml b/.travis.yml index d804cfe6..1908cb1b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,10 +1,14 @@ language: go go: - - 1.7.x + - 1.8.x - 1.x install: make setup +sudo: required # required for docker service +services: + # Required for dockertest + - docker script: # check compilation on supported targets - GOOS=linux GOARCH=amd64 make binaries diff --git a/CHANGELOG.md b/CHANGELOG.md index c0210e47..927e6a8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,12 +6,16 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [Unreleased] ### Added +* Experimental HTTP2 support +* Experimental connection pool ### Changed +* The top-level listener at the wh-server level now only listens on TCP - allowing each handler control over TLS/SSH ### Removed ### Fixed +* Complies with bug introduces from https://github.com/golang/go/issues/19767 ## [0.5.35] - 2017-06-15 ### Added diff --git a/Makefile b/Makefile index a29943f7..669e53e8 100644 --- a/Makefile +++ b/Makefile @@ -72,11 +72,11 @@ build: ## build the go packages test: ## run tests, except integration tests @echo "🎈 $@" - @go test -parallel 8 ${RACE} $(filter-out ${INTEGRATION_PACKAGE},${PACKAGES}) + @go test ${RACE} $(filter-out ${INTEGRATION_PACKAGE},${PACKAGES}) integration: ## run integration tests @echo "🎈 $@" - @go test -parallel 8 ${RACE} ${INTEGRATION_PACKAGE} + @go test ${RACE} ${INTEGRATION_PACKAGE} FORCE: diff --git a/README.md b/README.md index 6eeb300b..e077bd03 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Wormhole is a reverse proxy that creates a secure tunnel between two endpoints. ## Compiling -**Wormhole requires Go1.7+** +**Wormhole requires Go1.8+** go get github.com/superfly/wormhole cd $GOPATH/src/github.com/superfly/wormhole diff --git a/cmd/wormhole/main.go b/cmd/wormhole/main.go index 3fdb081d..719afbaa 100644 --- a/cmd/wormhole/main.go +++ b/cmd/wormhole/main.go @@ -5,8 +5,8 @@ import ( "fmt" "net/http" - log "github.com/Sirupsen/logrus" "github.com/prometheus/client_golang/prometheus/promhttp" + log "github.com/sirupsen/logrus" "github.com/superfly/wormhole" "github.com/superfly/wormhole/config" ) diff --git a/config/config.go b/config/config.go index a7d07645..343f7000 100644 --- a/config/config.go +++ b/config/config.go @@ -7,8 +7,8 @@ import ( "os" bugsnag_hook "github.com/Shopify/logrus-bugsnag" - "github.com/Sirupsen/logrus" bugsnag "github.com/bugsnag/bugsnag-go" + "github.com/sirupsen/logrus" "github.com/spf13/viper" prefixed "github.com/x-cray/logrus-prefixed-formatter" ) @@ -178,6 +178,18 @@ func NewServerConfig() (*ServerConfig, error) { } cfg.TLSPrivateKey = tlsKey + tlsCert, err := ioutil.ReadFile(viper.GetString("tls_cert_file")) + if err != nil { + return nil, cfgErr(unsetEnvStr, "FLY_TLS_CERT_FILE") + } + cfg.TLSCert = tlsCert + case HTTP2: + tlsKey, err := ioutil.ReadFile(viper.GetString("tls_private_key_file")) + if err != nil { + return nil, cfgErr(unsetEnvStr, "FLY_TLS_PRIVATE_KEY_FILE") + } + cfg.TLSPrivateKey = tlsKey + tlsCert, err := ioutil.ReadFile(viper.GetString("tls_cert_file")) if err != nil { return nil, cfgErr(unsetEnvStr, "FLY_TLS_CERT_FILE") @@ -278,7 +290,14 @@ func NewClientConfig() (*ClientConfig, error) { Logger: logger, } - if protocol == TLS { + switch protocol { + case TLS: + tlsCert, err := ioutil.ReadFile(viper.GetString("tls_cert_file")) + if err != nil { + return nil, cfgErr(unsetEnvStr, "FLY_TLS_CERT_FILE") + } + shared.TLSCert = tlsCert + case HTTP2: tlsCert, err := ioutil.ReadFile(viper.GetString("tls_cert_file")) if err != nil { return nil, cfgErr(unsetEnvStr, "FLY_TLS_CERT_FILE") @@ -351,6 +370,8 @@ const ( TCP // TLS connection pool TLS + // HTTP2 connection pool + HTTP2 _ _ _ @@ -367,6 +388,8 @@ func ParseTunnelProto(proto string) TunnelProto { return TCP case "tls": return TLS + case "http2": + return HTTP2 default: return UNSUPPORTED } diff --git a/glide.lock b/glide.lock index 708f859e..0fdcf9cd 100644 --- a/glide.lock +++ b/glide.lock @@ -1,8 +1,8 @@ -hash: 27d87619fc240ef63a996195f99ae03430cf74bfea1546bcf4ea12c0f61f5f42 -updated: 2017-04-27T17:46:19.870683507-07:00 +hash: 3e702f3a69224da236dac6b56faf98f63c3fdd4fc12d9d08409c303215fa3599 +updated: 2017-10-03T15:43:17.102621816-07:00 imports: - name: github.com/ant0ine/go-json-rest - version: 4602b00d2caab423578a3094c68137dcc1eb2051 + version: ebb33769ae013bd5f518a8bac348c310dea768b8 subpackages: - rest - rest/trie @@ -11,24 +11,24 @@ imports: subpackages: - quantile - name: github.com/bugsnag/bugsnag-go - version: 0e0aea7fe7d3fa47be41fc8a0e3a92d9668a9020 + version: 036f1af2a63f8133e596d1c127c86100b4642ba1 subpackages: - errors - name: github.com/bugsnag/panicwrap - version: aa7703c9414b36d4e9b2e42e6c704d0bfae7db64 + version: dd8df9a3778aaebc569794383e5c4ce87d6fd89e - name: github.com/fsnotify/fsnotify - version: a904159b9206978bb6d53fcc7a769e5cd726c737 + version: 4da3e2cfbabc9f751898f250b49f2439785783a1 - name: github.com/garyburd/redigo version: 8873b2f1995f59d4bcdd2b0dc9858e2cb9bf0c13 subpackages: - internal - redis - name: github.com/golang/protobuf - version: 69b215d01a5606c843240eab4937eab3acee6530 + version: 130e6b02ab059e7b717a096f397c5b60111cae74 subpackages: - proto - name: github.com/hashicorp/hcl - version: 630949a3c5fa3c613328e1b8256052cbc2327c9b + version: 68e816d1c783414e79bc65b3994d9ab6b0a722ab subpackages: - hcl/ast - hcl/parser @@ -38,32 +38,38 @@ imports: - json/parser - json/scanner - json/token +- name: github.com/jbenet/go-context + version: d14ea06fba99483203c19d92cfcd13ebe73135f4 + subpackages: + - io - name: github.com/jpillora/backoff - version: f24585d1c70490c0920ab34924f54f726e1416c7 + version: 8eab2debe79d12b7bd3d10653910df25fa9552ba - name: github.com/kardianos/osext - version: 9b883c5eb462dd5cb1b0a7a104fe86bc6b9bd391 + version: ae77be60afb1dcacde03767a8c37337fad28ac14 - name: github.com/klauspost/cpuid version: 09cded8978dc9e80714c4d85b0322337b0a1e5e0 - name: github.com/magiconair/properties - version: b3b15ef068fd0b17ddf408a23669f20811d194d2 + version: 8d7837e64d3c1ee4e54a880c5a920ab4316fc90a - name: github.com/mattn/go-colorable - version: 5411d3eea5978e6cdc258b30de592b60df6aba96 + version: ad5389df28cdac544c99bd7b9161a0b5b6ca9d1b - name: github.com/mattn/go-isatty - version: dda3de49cbfcec471bd7a70e6cc01fcc3ff90109 + version: a5cdd64afdee435007ee3e9f6ed4684af949d568 - name: github.com/matttproud/golang_protobuf_extensions version: c12348ce28de40eed0136aa2b644d0ee0650e56c subpackages: - pbutil - name: github.com/mgutz/ansi version: 9520e82c474b0a04dd04f8a40959027271bab992 +- name: github.com/mitchellh/go-homedir + version: b8bc1bf767474819792c23f32d8286a45736f1c6 - name: github.com/mitchellh/mapstructure - version: db1efb556f84b25a0a13a04aad883943538ad2e0 + version: d0303fe809921458f417bcf828397a65db30a7e4 - name: github.com/patrickmn/go-cache - version: 7ac151875ffb48b9f3ccce9ea20f020b0c1596c8 -- name: github.com/pelletier/go-buffruneio - version: df1e16fde7fc330a0ca68167c23bf7ed6ac31d6d + version: a3647f8e31d79543b2d0f0ae2fe5c379d72cedc0 - name: github.com/pelletier/go-toml - version: 22139eb5469018e7374b3e7ef653de37ffb44f72 + version: 2009e44b6f182e34d8ce081ac2767622937ea3d4 +- name: github.com/pkg/errors + version: 2b3a18b5f0fb6b4f9190549597d3f962c02bc5eb - name: github.com/prometheus/client_golang version: c5b7fccd204277076155f10851dad72b76a49317 subpackages: @@ -74,37 +80,37 @@ imports: subpackages: - go - name: github.com/prometheus/common - version: 49fee292b27bfff7f354ee0f64e1bc4850462edf + version: 2f17f4a9d485bf34b4bfaccc273805040e4f86c8 subpackages: - expfmt - internal/bitbucket.org/ww/goautoneg - model - name: github.com/prometheus/procfs - version: a1dba9ce8baed984a2495b658c82687f8157b98f + version: e645f4e5aaa8506fc71d6edbc5c4ff02c04c46f2 subpackages: - xfs - name: github.com/rs/xid - version: e959e92539c364578e455189c9250311a5160095 + version: 02dd45c33376f85d1064355dc790dcc4850596b1 - name: github.com/sergi/go-diff - version: 24e2351369ec4949b2ed0dc5c477afdd4c4034e8 + version: feef008d51ad2b3778f85d387ccf91735543008d subpackages: - diffmatchpatch - name: github.com/Shopify/logrus-bugsnag - version: 797fa877e4ab814c9a0e3fc1c77bf6e3beb76465 -- name: github.com/Sirupsen/logrus - version: ba1b36c82c5e05c4f912a88eab0dcd91a171688f + version: 6dbc35f2c30d1e37549f9673dd07912452ab28a5 +- name: github.com/sirupsen/logrus + version: f006c2ac4710855cf0f916dd6b77acf6b048dc6e - name: github.com/spf13/afero - version: 9be650865eab0c12963d8753212f4f9c66cdcf12 + version: 8a6ade7159a9b7fff5d2feac091c425177ac1b28 subpackages: - mem - name: github.com/spf13/cast - version: d1139bab1c07d5ad390a65e7305876b3c1a8370b + version: acbeb36b902d72a7a4c18e8f3241075e7ab763e4 - name: github.com/spf13/jwalterweatherman - version: fa7ca7e836cf3a8bb4ebf799f472c12d7e903d66 + version: 12bd96e66386c1960ab0f74ced1362f66f552f7b - name: github.com/spf13/pflag - version: 9ff6c6923cfffbcd502984b8e0c80539a94968b7 + version: be7121dd7a937a85e1e4b1ddda6a3edce3466110 - name: github.com/spf13/viper - version: 7538d73b4eb9511d85a9f1dfef202eeb8ac260f4 + version: d9cca5ef33035202efb1586825bdbb15ff9ec3ba - name: github.com/src-d/gcfg version: f187355171c936ac84a82793659ebb4936bc1c23 subpackages: @@ -112,32 +118,44 @@ imports: - token - types - name: github.com/ulule/limiter - version: c242da0b4c9524723c5a2dc8e7d49c228d1bb33c + version: 619f3ae8cc00f54934d27591e7010c8f6216c5ca - name: github.com/x-cray/logrus-prefixed-formatter - version: 9cd0cf058806896bf8ca66f9a957d36bd9e4ce0f + version: bb2702d423886830dee131692131d35648c382e2 +- name: github.com/xanzy/ssh-agent + version: ba9c9e33906f58169366275e3450db66139a31a9 - name: golang.org/x/crypto - version: 453249f01cfeb54c3d549ddb75ff152ca243f9d8 + version: 9419663f5a44be8b34ca85f08abc5fe1be11f8a3 subpackages: - curve25519 - ed25519 - ed25519/internal/edwards25519 - ssh - ssh/agent + - ssh/knownhosts + - ssh/terminal - name: golang.org/x/net - version: 6b27048ae5e6ad1ef927e72e437531493de612fe + version: 0a9397675ba34b2845f758fe3cd68828369c6517 subpackages: - context + - context/ctxhttp + - http2 + - http2/hpack + - idna + - lex/httplex - name: golang.org/x/sys - version: 075e574b89e4c2d22f2286a7e2b919519c6f3547 + version: 314a259e304ff91bd6985da2a7149bbf91237993 subpackages: - unix + - windows - name: golang.org/x/text - version: 85c29909967d7f171f821e7a42e7b7af76fb9598 + version: 1cbadb444a806fd9430d14ad08967ed91da4fa0a subpackages: + - secure/bidirule - transform + - unicode/bidi - unicode/norm - name: google.golang.org/appengine - version: 2e4a801b39fc199db615bfca7d0b9f8cd9580599 + version: 24e4144ec923c2374f6b06610c0df16a9222c3d9 subpackages: - datastore - internal @@ -147,12 +165,15 @@ imports: - internal/log - internal/modules - internal/remote_api -- name: gopkg.in/src-d/go-billy.v2 - version: 99d839800b93542496ec19319f37abacf17f589f +- name: gopkg.in/src-d/go-billy.v3 + version: c329b7bc7b9d24905d2bc1b85bfa29f7ae266314 subpackages: + - helper/chroot + - helper/polyfill - osfs + - util - name: gopkg.in/src-d/go-git.v4 - version: 36c78b9d1b1eea682703fb1cbb0f4f3144354389 + version: f9879dd043f84936a1f8acb8a53b74332a7ae135 subpackages: - config - internal/revision @@ -160,6 +181,8 @@ imports: - plumbing/cache - plumbing/filemode - plumbing/format/config + - plumbing/format/diff + - plumbing/format/gitignore - plumbing/format/idxfile - plumbing/format/index - plumbing/format/objfile @@ -187,6 +210,8 @@ imports: - utils/diff - utils/ioutil - utils/merkletrie + - utils/merkletrie/filesystem + - utils/merkletrie/index - utils/merkletrie/internal/frame - utils/merkletrie/noder - name: gopkg.in/vmihailenco/msgpack.v2 @@ -196,11 +221,76 @@ imports: - name: gopkg.in/warnings.v0 version: 8a331561fe74dadba6edfc59f3be66c22c3b065d - name: gopkg.in/yaml.v2 - version: a3f3340b5840cee44f372bddb5880fcbc419b46a + version: eb3733d160e74a9c7e442f435eb3bea458e1d19f testImports: - name: github.com/alicebob/miniredis - version: 10ddf01f45bee3c40d1af5dbaad9aa71e6f20835 -- name: github.com/bsm/redeo - version: 1ce09fc76693fb3c1ca9b529c66f38920beb6fb8 + version: 70cb3fab2cf658a95f2ccae5bfb2f0ccbba0c548 + subpackages: + - server +- name: github.com/Azure/go-ansiterm + version: d6e3b3328b783f23731bc4d058875b0371ff8109 + subpackages: + - winterm +- name: github.com/cenk/backoff + version: 32cd0c5b3aef12c76ed64aaf678f6c79736be7dc +- name: github.com/davecgh/go-spew + version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9 + subpackages: + - spew +- name: github.com/docker/docker + version: 89658bed64c2a8fe05a978e5b87dbec409d57a0f + subpackages: + - api/types + - api/types/blkiodev + - api/types/container + - api/types/filters + - api/types/mount + - api/types/network + - api/types/registry + - api/types/strslice + - api/types/swarm + - api/types/versions + - opts + - pkg/archive + - pkg/fileutils + - pkg/homedir + - pkg/idtools + - pkg/ioutils + - pkg/jsonlog + - pkg/jsonmessage + - pkg/longpath + - pkg/pools + - pkg/promise + - pkg/stdcopy + - pkg/system + - pkg/term + - pkg/term/windows +- name: github.com/docker/go-connections + version: 3ede32e2033de7505e6500d6c868c2b9ed9f169d + subpackages: + - nat +- name: github.com/docker/go-units + version: 0dadbb0345b35ec7ef35e228dabb8de89a65bf52 +- name: github.com/fsouza/go-dockerclient + version: 199e3d903f173ca5869445d99a5aebe85872a7a4 +- name: github.com/Microsoft/go-winio + version: 78439966b38d69bf38227fbf57ac8a6fee70f69a +- name: github.com/Nvveen/Gotty + version: cd527374f1e5bff4938207604a14f2e38a9cf512 +- name: github.com/opencontainers/runc + version: 0351df1c5a66838d0c392b4ac4cf9450de844e2d + subpackages: + - libcontainer/system + - libcontainer/user +- name: github.com/pmezard/go-difflib + version: d8ed2627bdf02c080bf22230dbb337003b7aba2d + subpackages: + - difflib +- name: github.com/stretchr/testify + version: 890a5c3458b43e6104ff5da8dfa139d013d77544 subpackages: - - info + - assert +- name: github.com/superfly/tlstest + version: 688815416c7e3edb2eb09d9dc14a31f9cb388cec +- name: gopkg.in/ory-am/dockertest.v3 + version: a7951f7a8442f0e70d36e499ed4d744f00af2963 diff --git a/glide.yaml b/glide.yaml index 90d14304..6bd89c1d 100644 --- a/glide.yaml +++ b/glide.yaml @@ -1,8 +1,6 @@ package: github.com/superfly/wormhole import: - package: github.com/jpillora/backoff -- package: github.com/Sirupsen/logrus - version: ~0.11.2 - package: github.com/x-cray/logrus-prefixed-formatter - package: github.com/garyburd/redigo version: ~1.0.0 @@ -22,5 +20,7 @@ import: subpackages: - prometheus - package: github.com/ulule/limiter +- package: github.com/sirupsen/logrus + version: ^1.0.3 testImport: - package: github.com/alicebob/miniredis diff --git a/integration/integration_test.go b/integration/integration_test.go index 50e1c754..e8e67d7c 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -7,7 +7,7 @@ import ( "runtime" "testing" - "github.com/Sirupsen/logrus" + "github.com/sirupsen/logrus" ) var showTrace = flag.Bool("show-trace", false, "show stack trace after tests finish") diff --git a/integration_test.go b/integration_test.go index 89106728..755f1796 100644 --- a/integration_test.go +++ b/integration_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/Sirupsen/logrus" "github.com/alicebob/miniredis" + "github.com/sirupsen/logrus" "github.com/superfly/wormhole" "github.com/superfly/wormhole/config" diff --git a/local.go b/local.go index 49d347cc..eb696b67 100644 --- a/local.go +++ b/local.go @@ -6,8 +6,8 @@ import ( "strings" "time" - "github.com/Sirupsen/logrus" "github.com/jpillora/backoff" + "github.com/sirupsen/logrus" "github.com/superfly/wormhole/config" "github.com/superfly/wormhole/local" @@ -41,6 +41,12 @@ func StartLocal(cfg *config.ClientConfig) { if err != nil { log.Fatal(err) } + case config.HTTP2: + handler, err = local.NewHTTP2Handler(cfg, release) + if err != nil { + log.Fatal(err) + } + default: log.Fatal("Unknown wormhole transport layer protocol selected.") } diff --git a/local/http2_handler.go b/local/http2_handler.go new file mode 100644 index 00000000..279fa43c --- /dev/null +++ b/local/http2_handler.go @@ -0,0 +1,287 @@ +package local + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + "net" + "net/http" + "sync/atomic" + "time" + + "github.com/sirupsen/logrus" + "github.com/superfly/wormhole/config" + "github.com/superfly/wormhole/messages" + wnet "github.com/superfly/wormhole/net" + "golang.org/x/net/http2" +) + +// HTTP2Handler type represents the handler that opens a TCP conn to wormhole server and serves +// incoming requests +type HTTP2Handler struct { + RemoteEndpoint string + LocalEndpoint string + FlyToken string + Release *messages.Release + Version string + ln net.Listener + control net.Conn + conns []net.Conn + server *http2.Server + transport *http2.Transport + fClient *http.Client + tlsConfig *tls.Config + lastPongAt int64 + logger *logrus.Entry +} + +// NewHTTP2Handler returns a HTTP2Handler struct with TLS encryption +func NewHTTP2Handler(cfg *config.ClientConfig, release *messages.Release) (*HTTP2Handler, error) { + rootCAs := x509.NewCertPool() + ok := rootCAs.AppendCertsFromPEM(cfg.TLSCert) + if !ok { + return nil, fmt.Errorf("couln't append a root CA") + } + + tlsHost, _, err := net.SplitHostPort(cfg.RemoteEndpoint) + if err != nil { + return nil, err + } + + h := &HTTP2Handler{ + FlyToken: cfg.Token, + RemoteEndpoint: cfg.RemoteEndpoint, + LocalEndpoint: cfg.LocalEndpoint, + Release: release, + Version: cfg.Version, + tlsConfig: &tls.Config{RootCAs: rootCAs, ServerName: tlsHost}, + server: &http2.Server{}, + transport: &http2.Transport{}, + fClient: &http.Client{}, + logger: cfg.Logger.WithFields(logrus.Fields{"prefix": "HTTP2Handler"}), + } + + return h, nil +} + +// ListenAndServe accepts requests coming from wormhole server +// and forwards them to the local server +func (s *HTTP2Handler) ListenAndServe() error { + control, err := s.dialControl() + if err != nil { + return err + } + defer control.Close() + + s.control = control + ctlAuthMsg := &messages.AuthControl{ + Token: s.FlyToken, + } + buf, err := messages.Pack(ctlAuthMsg) + if err != nil { + return fmt.Errorf("error packing message to control: " + err.Error()) + } + + _, err = s.control.Write(buf) + if err != nil { + return fmt.Errorf("error writing to control: " + err.Error()) + } + + s.lastPongAt = time.Now().UnixNano() + go s.heartbeat() + + b := make([]byte, 1024) + for { + nr, err := s.control.Read(b) + if err == io.EOF { + continue + } + if err != nil { + return fmt.Errorf("error reading from control: " + err.Error()) + } + msg, err := messages.Unpack(b[:nr]) + if err != nil { + return fmt.Errorf("error parsing message from stream: " + err.Error()) + } + switch m := msg.(type) { + case *messages.OpenTunnel: + s.logger.Debug("Received Open Tunnel message.") + tcpConn, err := s.dial() + if err != nil { + return err + } + genericTLSConn, err := s.genericTLSWrap(tcpConn) + if err != nil { + return err + } + authMsg := &messages.AuthTunnel{ClientID: m.ClientID, Token: s.FlyToken} + b, _ := messages.Pack(authMsg) + _, err = genericTLSConn.Write(b) + if err != nil { + return fmt.Errorf("Failed to auth tunnel: %s", err.Error()) + } + // TODO: Listen for auth ACK + + http2TLSConn, err := s.http2ALPNTLSWrap(tcpConn) + if err != nil { + return err + } + + s.logger.Infof("Established TLS connection for Session: %s", m.ClientID) + s.conns = append(s.conns, http2TLSConn) + + opts := &http2.ServeConnOpts{ + // We are our own handler + Handler: s, + } + s.logger.Info("Serving http2 Connection") + go s.server.ServeConn(http2TLSConn, opts) + case *messages.Shutdown: + s.logger.Debugf("Received Shutdown message: %s", m.Error) + return s.Close() + case *messages.Pong: + atomic.StoreInt64(&s.lastPongAt, time.Now().UnixNano()) + default: + s.logger.Warn("Unrecognized command. Ignoring.") + } + } +} + +func (s *HTTP2Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + r.URL.Host = s.LocalEndpoint + // TODO: Figure out support for https + r.URL.Scheme = "http" + r.Host = s.LocalEndpoint + r.RequestURI = "" + + // We ignore the error ONLY because it will be forwarded + // to the end user + // TODO: Handle this error + resp, err := s.fClient.Do(r) + if err != nil { + s.logger.Error(err) + } + + // Delete this so we don't copy it over + // Will be handled by http.ResponseWriter + resp.Header.Del("Content-Length") + + for key, values := range resp.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + + w.WriteHeader(resp.StatusCode) + + defer resp.Body.Close() + + nr, err := io.Copy(w, resp.Body) + if err != nil { + s.logger.Errorf("Could not copy response body") + return + } + s.logger.Infof("Copied %d bytes between connection bodies", nr) +} + +// Close closes the listener and TCP connection +func (s *HTTP2Handler) Close() error { + err := s.control.Close() + if err != nil { + s.logger.Errorf("Control TCP conn close: %s", err) + } + for _, c := range s.conns { + err = c.Close() + if err != nil { + s.logger.Errorf("Proxy http2 conn close: %s", err) + } + } + return err +} + +// dial opens an unencrypted TCP connection to a server +func (s *HTTP2Handler) dial() (*net.TCPConn, error) { + conn, err := net.Dial("tcp", s.RemoteEndpoint) + if err != nil { + return nil, err + } + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + return nil, errors.New("Error: could not cast tcp connection") + } + + return tcpConn, nil +} + +func (s *HTTP2Handler) dialControl() (net.Conn, error) { + conn, err := s.dial() + if err != nil { + return nil, err + } + + cConn, err := s.genericTLSWrap(conn) + if err != nil { + return nil, err + } + + return cConn, nil +} + +func (s *HTTP2Handler) genericTLSWrap(conn *net.TCPConn) (*tls.Conn, error) { + return wnet.GenericTLSWrap(conn, s.tlsConfig, tls.Client) +} + +// This wrapper fulfills the requirement for specifying the 'h2' ALPN TLS negotiation for +// TLS enabled http2 connections +// +// NOTE: The ALPN is a requirement of the spec for HTTP/2 capability discovery +// While technically the golang implementation will allow us not to perform ALPN, +// this breaks the http/2 spec. The goal here is to follow the RFC to the letter +// as documented in http://httpwg.org/specs/rfc7540.html#starting +func (s *HTTP2Handler) http2ALPNTLSWrap(conn *net.TCPConn) (*tls.Conn, error) { + return wnet.HTTP2ALPNTLSWrap(conn, s.tlsConfig, tls.Client) +} + +func (s *HTTP2Handler) heartbeat() { + // set lastPing to something sane + lastPing := time.Unix(atomic.LoadInt64(&s.lastPongAt)-1, 0) + ping := time.NewTicker(pingInterval) + pongCheck := time.NewTicker(time.Second) + + defer func() { + s.control.Close() + ping.Stop() + pongCheck.Stop() + }() + + for { + select { + case <-pongCheck.C: + lastPong := time.Unix(0, atomic.LoadInt64(&s.lastPongAt)) + needPong := lastPong.Sub(lastPing) < 0 + pongLatency := time.Since(lastPing) + + if needPong && pongLatency > maxPongLatency { + s.logger.Infof("Last ping: %v, Last pong: %v", lastPing, lastPong) + s.logger.Infof("Connection stale, haven't gotten PongMsg in %d seconds", int(pongLatency.Seconds())) + return + } + + case <-ping.C: + b, err := messages.Pack(&messages.Ping{}) + if err != nil { + s.logger.Errorf("Got error %v when creating PingMsg", err) + return + } + _, err = s.control.Write(b) + if err != nil { + s.logger.Errorf("Got error %v when writing PingMsg", err) + return + } + s.logger.Debug("Sent Ping message") + lastPing = time.Now() + } + } +} diff --git a/local/http2_handler_test.go b/local/http2_handler_test.go new file mode 100644 index 00000000..70e7841a --- /dev/null +++ b/local/http2_handler_test.go @@ -0,0 +1,243 @@ +package local + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/superfly/tlstest" + "github.com/superfly/wormhole/config" + "github.com/superfly/wormhole/messages" + wnet "github.com/superfly/wormhole/net" + "golang.org/x/net/http2" + + "os" +) + +var httpTestServer *httptest.Server +var testBody string + +var testTLSServerConfig *tls.Config +var testTLSClientConfig *tls.Config + +var testTLSCACert []byte + +var testRemoteListener *net.TCPListener + +func init() { + testBody = "test" + + httpTestServer = httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, testBody) + })) + + var testTLSServerCertPEM []byte + var testTLSServerKeyPEM []byte + var err error + + testTLSCACert, testTLSServerCertPEM, testTLSServerKeyPEM, err = tlstest.CreateServerCertKeyPEMPairWithRootCert() + if err != nil { + os.Exit(1) + } + + servTLSCert, err := tls.X509KeyPair(testTLSServerCertPEM, testTLSServerKeyPEM) + if err != nil { + os.Exit(1) + } + + testTLSServerConfig = &tls.Config{ + Certificates: []tls.Certificate{servTLSCert}, + } + + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(testTLSCACert) + + testTLSClientConfig = &tls.Config{ + RootCAs: certPool, + ServerName: "127.0.0.1", + } + + tAddr, err := net.ResolveTCPAddr("tcp4", "127.0.0.1:8001") + if err != nil { + os.Exit(1) + } + + testRemoteListener, err = net.ListenTCP("tcp4", tAddr) + if err != nil { + os.Exit(1) + } +} + +func TestNewHTTP2Handler(t *testing.T) { + testCfg := &config.ClientConfig{ + Config: config.Config{ + Logger: logrus.New(), + Version: "test_version", + TLSCert: testTLSCACert, + }, + Token: "test_token", + LocalEndpoint: httpTestServer.Listener.Addr().String(), + RemoteEndpoint: testRemoteListener.Addr().String(), + } + + testRelease := &messages.Release{ + ID: "test_id", + } + + expHandler, err := NewHTTP2Handler(testCfg, testRelease) + assert.NoError(t, err, "Should be no error creating new handler") + + controlHandler := &HTTP2Handler{ + RemoteEndpoint: testRemoteListener.Addr().String(), + FlyToken: "test_token", + Version: "test_version", + LocalEndpoint: httpTestServer.Listener.Addr().String(), + tlsConfig: testTLSClientConfig, + } + + assert.Equal(t, expHandler.RemoteEndpoint, controlHandler.RemoteEndpoint, "Remote endpoints should match") + assert.Equal(t, expHandler.LocalEndpoint, controlHandler.LocalEndpoint, "Local endpoints should match") + assert.Equal(t, expHandler.Version, controlHandler.Version, "Versions should match") + assert.EqualValues(t, expHandler.tlsConfig, controlHandler.tlsConfig, "TLS Configs should match") +} + +func newTestHTTP2Handler() (*HTTP2Handler, error) { + testCfg := &config.ClientConfig{ + Config: config.Config{ + Logger: logrus.New(), + Version: "test_version", + TLSCert: testTLSCACert, + }, + Token: "test_token", + LocalEndpoint: httpTestServer.Listener.Addr().String(), + RemoteEndpoint: testRemoteListener.Addr().String(), + } + + testRelease := &messages.Release{ + ID: "test_id", + } + + return NewHTTP2Handler(testCfg, testRelease) +} + +func TestHTTP2Handler(t *testing.T) { + h, err := newTestHTTP2Handler() + assert.NoError(t, err, "Should be no error creating test handler") + + t.Run("Test_dial", func(t *testing.T) { + sConnCh := make(chan *net.TCPConn) + + go func(ln *net.TCPListener) { + s, err := ln.AcceptTCP() + assert.NoError(t, err, "Should be no error in accept") + sConnCh <- s + }(testRemoteListener) + + _, err = h.dial() + assert.NoError(t, err, "Should be no error getting connectio") + + _, ok := <-sConnCh + assert.True(t, ok, "We should have no issue getting conn") + }) + t.Run("Test_dial_control", func(t *testing.T) { + sConnCh := make(chan *tls.Conn) + + go func(ln *net.TCPListener) { + s, err := ln.AcceptTCP() + assert.NoError(t, err, "Should be no error in accept") + sTLS, err := wnet.GenericTLSWrap(s, testTLSServerConfig, tls.Server) + assert.NoError(t, err, "Should be no error wrapping tls conn") + sConnCh <- sTLS + }(testRemoteListener) + + _, err = h.dialControl() + assert.NoError(t, err, "Should be no error getting connectio") + + _, ok := <-sConnCh + assert.True(t, ok, "We should have no issue getting conn") + }) + t.Run("Test_listen_and_serve", func(t *testing.T) { + go func() { + for { + // We may have TCP errors to recover from + _ = h.ListenAndServe() + } + }() + controlConn, err := testRemoteListener.AcceptTCP() + assert.NoError(t, err, "Should have no error accepting control conn from handler") + + controlCTLS, err := wnet.GenericTLSWrap(controlConn, testTLSServerConfig, tls.Server) + assert.NoError(t, err, "Should have no error wrapping control conn with TLS") + + buf := make([]byte, 1024) + + nr, err := controlCTLS.Read(buf) + assert.NoError(t, err, "Should be no error reading from handler conn") + + msg, err := messages.Unpack(buf[:nr]) + assert.NoError(t, err, "Should be no error unpacking msg") + + authMsg, ok := msg.(*messages.AuthControl) + assert.True(t, ok, "Should be an AuthControl message") + + assert.Equal(t, authMsg.Token, h.FlyToken) + + t.Run("Test_open_tunnel", func(t *testing.T) { + oMsg := &messages.OpenTunnel{ + ClientID: "test", + } + + buf, err := messages.Pack(oMsg) + assert.NoError(t, err, "Should have no error packing messages") + + _, err = controlCTLS.Write(buf) + assert.NoError(t, err, "Should have no error writing message") + + tunConn, err := testRemoteListener.AcceptTCP() + assert.NoError(t, err, "Should have no error accepting tunnel") + + tunTLSConn, err := wnet.GenericTLSWrap(tunConn, testTLSServerConfig, tls.Server) + assert.NoError(t, err, "Should have no error wrapping tunnel") + + buf = make([]byte, 1024) + + nr, err := tunTLSConn.Read(buf) + assert.NoError(t, err, "Should have no error reading from tunnel conn") + + msg, err := messages.Unpack(buf[:nr]) + assert.NoError(t, err, "Should have no error unpacking message") + + authTunMsg, ok := msg.(*messages.AuthTunnel) + assert.True(t, ok, "Should be of type authtunnel") + + assert.Equal(t, authTunMsg.ClientID, oMsg.ClientID, "Should have same clientID as openTunnel") + assert.Equal(t, authTunMsg.Token, h.FlyToken, "Should have matching tokens") + + alpnConn, err := wnet.HTTP2ALPNTLSWrap(tunConn, testTLSServerConfig, tls.Server) + assert.NoError(t, err, "Should be no error wrapping alpnConn") + + tr := &http2.Transport{} + http2Client, err := tr.NewClientConn(alpnConn) + assert.NoError(t, err, "Should be no error creating new client conn") + + req, err := http.NewRequest("GET", "https://127.0.0.1:8000", nil) + assert.NoError(t, err, "Should have no error making request") + + resp, err := http2Client.RoundTrip(req) + assert.NoError(t, err, "Should have no error sending request") + + body, err := ioutil.ReadAll(resp.Body) + assert.NoError(t, err, "Should have no error reading body") + + assert.Equal(t, testBody, string(body)) + }) + }) +} diff --git a/local/ssh_handler.go b/local/ssh_handler.go index 06862786..e79f6c7a 100644 --- a/local/ssh_handler.go +++ b/local/ssh_handler.go @@ -10,7 +10,7 @@ import ( msgpack "gopkg.in/vmihailenco/msgpack.v2" - "github.com/Sirupsen/logrus" + "github.com/sirupsen/logrus" "github.com/superfly/wormhole/config" "github.com/superfly/wormhole/messages" wnet "github.com/superfly/wormhole/net" @@ -118,7 +118,8 @@ func (s *SSHHandler) dial() (*ssh.Client, net.Listener, error) { Auth: []ssh.AuthMethod{ ssh.Password(s.FlyToken), }, - Timeout: sshConnTimeout, + Timeout: sshConnTimeout, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), } // SSH into wormhole server diff --git a/local/tcp_handler.go b/local/tcp_handler.go index 154518ed..462db0c9 100644 --- a/local/tcp_handler.go +++ b/local/tcp_handler.go @@ -9,7 +9,7 @@ import ( "sync/atomic" "time" - "github.com/Sirupsen/logrus" + "github.com/sirupsen/logrus" "github.com/superfly/wormhole/config" "github.com/superfly/wormhole/messages" wnet "github.com/superfly/wormhole/net" diff --git a/net/conn_pool.go b/net/conn_pool.go new file mode 100644 index 00000000..49b2cc7e --- /dev/null +++ b/net/conn_pool.go @@ -0,0 +1,243 @@ +package net + +import ( + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "sync" + "sync/atomic" +) + +// connPool is designed to be a speedy connection pool +// Extensive testing for race conditions is a priority. +// DO NOT edit without updating or checking against existing tests +type connPool struct { + currentConn *connNode + numConns int64 + maxConns int + + logger *logrus.Entry + delConnCh chan *connNode + goodConnCh chan *connNode + insertedWhileNoneCh chan interface{} + waitingForConn int32 + + // Can't use RWMutex because even when reading we are modifying the currentConn + // Due to this we will never defer, since this is critical path code. Make SURE you unlock + // at every exit point. And attempt to have as few exit points as possible + sync.Mutex +} + +// ConnPoolObject represents an object to be handled by the connection pool +// NOTE: All functions implemented must be concurrency safe! +type ConnPoolObject interface { + // Close handles any cleanup required for the connection + // and is called whenever an object has been queued for deletion + Close() error + + // ShouldQueue indicates whether a ConnPoolObject should be queued in the connection pool + // An example here could be, if an conn-object can't be multiplexed, then return false + // whenever the object is in-use + // NOTE: ShouldQueue should resolve quickly, as it has the possibility to deadlock the queue + ShouldQueue() bool + + // ShouldDelete indicates whether a ConnPoolObject should be deleted from the pool + // An example here would be if an http2 connection runs out of streams + // NOTE: ShouldDelete should resolve quickly, as it has the possibility to deadlock the queue + ShouldDelete() bool +} + +// ConnPool is a fast concurrency safe connection pool structure +type ConnPool interface { + // Insert returns a no error false case only when we try to insert + // beyond our max connections limit + Insert(ConnPoolObject) (bool, error) + Get() ConnPoolObject +} + +// NewConnPool creates a new ConnPool +func NewConnPool(logger *logrus.Entry, maxConns int, initialConns []ConnPoolObject) (ConnPool, error) { + pool := &connPool{ + maxConns: maxConns, + logger: logger, + delConnCh: make(chan *connNode, maxConns), + goodConnCh: make(chan *connNode, maxConns), + insertedWhileNoneCh: make(chan interface{}), + waitingForConn: 0, + } + + for _, conn := range initialConns { + ok, err := pool.Insert(conn) + if !ok { + return nil, errors.New("Number of initialConns exceeds maxConns") + } + if err != nil { + return nil, errors.Wrapf(err, "Could not insert initial object: %v", conn) + } + } + + go pool.populateAvailable() + go pool.delLoop() + return pool, nil +} + +// connNode is a circulary linked list of connection +// instead of using an array where the earlier connections would +// get much higher load, we'll just run around the loop, so we balance +// in true round-robin fashion +type connNode struct { + prev *connNode + next *connNode + obj ConnPoolObject + + // deleted is not updated concurrently + // safe to manipulate unlocked + deleted uint32 +} + +func (pool *connPool) delLoop() { + for { + delConn := <-pool.delConnCh + // spawn a new goroutine even though + // delExistingConn immediately locks the pool + // because the go runtime does clever things to organize + // the goroutine mapping to help unlock as often as possible + go pool.delExistingConn(delConn) + } +} + +// delExistingConn deletes an connNode from the pool +// to be called after all requests/streams have been resolved +// or else the garbage collector could reap the connection before +// all data has been transferred +// the node MUST be in the list currently or be deleted +// NOTE: this is only to be called from the delete chan loop +func (pool *connPool) delExistingConn(hc *connNode) { + pool.Lock() + + // mark deleted in case + alreadyDeleted := !atomic.CompareAndSwapUint32(&hc.deleted, 0, 1) + if alreadyDeleted { + pool.Unlock() + pool.logger.Info("Caught multiple delete request") + return + } + + if atomic.LoadInt64(&pool.numConns) == 1 { + if err := hc.obj.Close(); err != nil { + pool.logger.Errorf("Error cleaning up connection object: %v+", err) + } + pool.currentConn = nil + atomic.StoreInt64(&pool.numConns, 0) + + pool.Unlock() + return + } + if hc == pool.currentConn { + pool.currentConn = pool.currentConn.next + } + hc.prev.next = hc.next + hc.next.prev = hc.prev + atomic.AddInt64(&pool.numConns, -1) + + pool.Unlock() + return +} + +// Insert adds a new conn to the end of the circulary linked list +func (pool *connPool) Insert(obj ConnPoolObject) (bool, error) { + pool.Lock() + // don't defer here. Defer has perf implications and this is critical path + + // don't insert a new connection when we've maxed out + if atomic.LoadInt64(&pool.numConns) >= int64(pool.maxConns) && pool.maxConns > 0 { + pool.Unlock() + return false, nil + } + + newConn := &connNode{ + obj: obj, + } + + if atomic.LoadInt64(&pool.numConns) == 0 { + newConn.next = newConn + newConn.prev = newConn + pool.currentConn = newConn + atomic.AddInt64(&pool.numConns, 1) + + if atomic.CompareAndSwapInt32(&pool.waitingForConn, 1, 0) { + pool.insertedWhileNoneCh <- struct{}{} + } + + pool.Unlock() + return true, nil + } + + newConn.next = pool.currentConn + newConn.prev = pool.currentConn.prev + pool.currentConn.prev.next = newConn + pool.currentConn.prev = newConn + atomic.AddInt64(&pool.numConns, 1) + + pool.Unlock() + return true, nil +} + +// populateAvailable is a run loop which constantly updates the channel of +// connections to be used +func (pool *connPool) populateAvailable() { + for { + + pool.Lock() + + if pool.currentConn == nil { + // ensure the Insert knows we're waiting for a connection + atomic.StoreInt32(&pool.waitingForConn, 1) + pool.Unlock() + + pool.logger.Warn("No connections in connection pool currently: waiting") + + // wait for insertion + //<-pool.insertedWhileNoneCh + <-pool.insertedWhileNoneCh + + pool.logger.Warn("No connections in connection pool currently: got new conn notification") + // TODO: add more connections in this case. Some sort of queue or signal system + continue + } + + if pool.currentConn.obj.ShouldQueue() { + retConn := pool.currentConn + + // iterate to next conn for next request + pool.currentConn = pool.currentConn.next + + select { + case pool.goodConnCh <- retConn: + default: + } + pool.Unlock() + continue + } else if pool.currentConn.obj.ShouldDelete() { + // mark for deletion + // NOTE: Once a Connection can no longer take a new request + // it never will be able to again. Therefore a conn will always go down the delete + // chan. Repeats down the chan are handled in the del method + pool.delConnCh <- pool.currentConn + // after marking for deletion, move on to next connection + pool.currentConn = pool.currentConn.next + } + + pool.Unlock() + } +} + +// Get returns a ConnPoolObject +// TODO: Allow set timeout +func (pool *connPool) Get() ConnPoolObject { + conn := <-pool.goodConnCh + for !conn.obj.ShouldQueue() { + // ensure state of conn hasn't changed before we return it + conn = <-pool.goodConnCh + } + return conn.obj +} diff --git a/net/conn_pool_test.go b/net/conn_pool_test.go new file mode 100644 index 00000000..431713e2 --- /dev/null +++ b/net/conn_pool_test.go @@ -0,0 +1,95 @@ +package net + +import ( + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "sync" + _ "sync/atomic" + "testing" + _ "time" + _ "unsafe" +) + +type testConnPoolObj struct { + canContinue bool + shouldQueue bool + sync.Mutex +} + +func (c *testConnPoolObj) ShouldQueue() bool { + return c.shouldQueue +} + +func (c *testConnPoolObj) Close() error { + return nil +} + +func (c *testConnPoolObj) ShouldDelete() bool { + return !c.canContinue +} + +func newBaseConnPool(initial []ConnPoolObject, max int) (ConnPool, error) { + logger := logrus.WithFields(logrus.Fields{"scope": "testing"}) + return NewConnPool(logger, max, initial) +} + +func TestInsert(t *testing.T) { + pool, err := newBaseConnPool([]ConnPoolObject{}, 10) + assert.NoError(t, err, "Should be no error creating pool") + + obj := &testConnPoolObj{ + canContinue: true, + shouldQueue: true, + } + ok, err := pool.Insert(obj) + assert.True(t, ok, "Should have room in pool") + assert.NoError(t, err, "Should have no error inserting into pool") + + objSame := pool.Get() + assert.Equal(t, obj, objSame, "With just one object in pool we should have only one come back") +} + +func TestInsertMulti(t *testing.T) { + logger := logrus.WithFields(logrus.Fields{"scope": "test_multi"}) + pool, err := newBaseConnPool([]ConnPoolObject{}, 2) + assert.NoError(t, err, "Should be no error creating pool") + + obj1 := &testConnPoolObj{ + canContinue: true, + shouldQueue: true, + } + logger.Info("Insert 1") + ok, err := pool.Insert(obj1) + assert.True(t, ok, "Should have room in pool") + assert.NoError(t, err, "Should have no error inserting into pool") + + obj2 := &testConnPoolObj{ + canContinue: true, + shouldQueue: true, + } + logger.Info("Insert 2") + ok, err = pool.Insert(obj2) + assert.True(t, ok, "Should have room in pool") + assert.NoError(t, err, "Should have no error inserting into pool") + + // Pool is not guaranteed to return in a particular order soon after inserting + // This we need to pull of 2xbuffer length to ensure we get the 2 we inserted + logger.Info("Get 1") + objGet1 := pool.Get() + + logger.Info("Get 2") + objGet2 := pool.Get() + + logger.Info("Get 3") + objGet3 := pool.Get() + + logger.Info("Get 4") + objGet4 := pool.Get() + + assert.True(t, objGet1 == obj1 || objGet1 == obj2, "Everything we get should be in the set we inserted-1") + assert.True(t, objGet2 == obj1 || objGet2 == obj2, "Everything we get should be in the set we inserted-2") + assert.True(t, objGet3 == obj1 || objGet3 == obj2, "Everything we get should be in the set we inserted-3") + assert.True(t, objGet4 == obj1 || objGet4 == obj2, "Everything we get should be in the set we inserted-4") + + assert.True(t, objGet1 != objGet2 || objGet1 != objGet3 || objGet1 != objGet4, "Use demorgan's to test for set completeness") +} diff --git a/net/utils.go b/net/utils.go index db3995c6..168509c0 100644 --- a/net/utils.go +++ b/net/utils.go @@ -1,9 +1,14 @@ package net import ( + "crypto/tls" + "fmt" + "golang.org/x/net/http2" "io" "net" + "reflect" "strings" + "time" ) // CopyDirection describes the direction of data copying in full-duplex link @@ -41,6 +46,100 @@ func (me multiError) Error() string { return errStr } +// TLSWrapperFunc represents a TLS Wrapper. This is intended to be either +// tls.Client or tls.Server see https://golang.org/pkg/crypto/tls/ for info +type TLSWrapperFunc func(conn net.Conn, cfg *tls.Config) *tls.Conn + +// GenericTLSWrap takes a TCP connection, a tls config, and an upgrade function +// and returns the new connection +func GenericTLSWrap(conn *net.TCPConn, cfg *tls.Config, tFunc TLSWrapperFunc) (*tls.Conn, error) { + var tConn *tls.Conn + + for { + if err := conn.SetDeadline(time.Now().Add(time.Second * 5)); err != nil { + return nil, err + } + + tConn = tFunc(conn, cfg) + + // check if the connection is upgraded before returning + // we want to catch the error early + if err := tConn.Handshake(); err != nil { + if netErr, ok := err.(net.Error); ok { + if netErr.Timeout() || netErr.Temporary() { + continue + } + } + return nil, err + } + if err := conn.SetDeadline(time.Time{}); err != nil { + return nil, err + } + break + } + + return tConn, nil +} + +// HTTP2ALPNTLSWrap returns a TLS connection that has been negotiated with `h2` ALPN +// tFunc must be either tls.Client or tls.Server. See https://golang.org/pkg/crypto/tls/ +// for proper usage of the tls.Config with either of these options +// +// NOTE: The ALPN is a requirement of the spec for HTTP/2 capability discovery +// While technically the golang implementation will allow us not to perform ALPN, +// this breaks the http/2 spec. The goal here is to follow the RFC to the letter +// as documented in http://httpwg.org/specs/rfc7540.html#starting +func HTTP2ALPNTLSWrap(conn *net.TCPConn, cfg *tls.Config, tFunc TLSWrapperFunc) (*tls.Conn, error) { + protoCfg := cfg.Clone() + // TODO: append here + protoCfg.NextProtos = []string{http2.NextProtoTLS} + + var tlsConn *tls.Conn + for { + if err := conn.SetDeadline(time.Now().Add(time.Second * 5)); err != nil { + return nil, err + } + tlsConn = tFunc(conn, protoCfg) + + if err := tlsConn.Handshake(); err != nil { + if netErr, ok := err.(net.Error); ok { + if netErr.Timeout() || netErr.Temporary() { + continue + } + } + return nil, err + } + if err := conn.SetDeadline(time.Time{}); err != nil { + return nil, err + } + break + } + + // Check if we're creating a client conn before checking verification + if isTLSClient(tFunc) { + if !protoCfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(protoCfg.ServerName); err != nil { + return nil, err + } + } + } + + state := tlsConn.ConnectionState() + if p := state.NegotiatedProtocol; p != http2.NextProtoTLS { + return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2.NextProtoTLS) + } + + if !state.NegotiatedProtocolIsMutual { + return nil, fmt.Errorf("http2: could not negotiate protocol mutually") + } + + return tlsConn, nil +} + +func isTLSClient(tFunc TLSWrapperFunc) bool { + return reflect.ValueOf(tFunc).Pointer() == reflect.ValueOf(tls.Client).Pointer() +} + // CopyCloseIO establishes a full-duplex link between 2 ReadWriteClosers func CopyCloseIO(lconn, rconn io.ReadWriteCloser) (lconnWritten, rconnWritten int64, err error) { copyCh := make(chan copyStatus, 2) diff --git a/process.go b/process.go index be8af4fe..affb7dfc 100644 --- a/process.go +++ b/process.go @@ -7,7 +7,7 @@ import ( "os/signal" "syscall" - "github.com/Sirupsen/logrus" + "github.com/sirupsen/logrus" ) // Process is a wrapper around external program diff --git a/release.go b/release.go index c0351095..036fa5af 100644 --- a/release.go +++ b/release.go @@ -50,7 +50,7 @@ func computeRelease(id, desc, branch string) (*messages.Release, error) { var branches []string refs.ForEach(func(ref *plumbing.Reference) error { - if ref.IsBranch() && head.Hash().String() == ref.Hash().String() { + if ref.Name().IsBranch() && head.Hash().String() == ref.Hash().String() { branch := strings.TrimPrefix(ref.Name().String(), "refs/heads/") branches = append(branches, branch) } diff --git a/remote.go b/remote.go index 798f32b4..cc3109d4 100644 --- a/remote.go +++ b/remote.go @@ -7,8 +7,8 @@ import ( "syscall" "time" - "github.com/Sirupsen/logrus" "github.com/garyburd/redigo/redis" + "github.com/sirupsen/logrus" "github.com/superfly/wormhole/config" handler "github.com/superfly/wormhole/remote" ) @@ -33,15 +33,16 @@ func StartRemote(cfg *config.ServerConfig) { if err != nil { log.Fatal(err) } - case config.TLS: - server.TLSCert = &cfg.TLSCert - server.TLSPrivateKey = &cfg.TLSPrivateKey - fallthrough case config.TCP: h, err = handler.NewTCPHandler(cfg, redisPool) if err != nil { log.Fatal(err) } + case config.HTTP2: + h, err = handler.NewHTTP2Handler(cfg, redisPool) + if err != nil { + log.Fatal(err) + } default: log.Fatal("Unknown wormhole transport layer protocol selected.") } diff --git a/remote/http2_handler.go b/remote/http2_handler.go new file mode 100644 index 00000000..c0c9a3f2 --- /dev/null +++ b/remote/http2_handler.go @@ -0,0 +1,177 @@ +package remote + +import ( + "crypto/tls" + "net" + + "github.com/garyburd/redigo/redis" + "github.com/sirupsen/logrus" + "github.com/superfly/wormhole/config" + "github.com/superfly/wormhole/messages" + wnet "github.com/superfly/wormhole/net" + "github.com/superfly/wormhole/session" +) + +// HTTP2Handler type represents the handler that accepts incoming wormhole connections +type HTTP2Handler struct { + nodeID string + localhost string + clusterURL string + sessions map[string]session.Session + pool *redis.Pool + logger *logrus.Entry + tlsConfig *tls.Config +} + +// NewHTTP2Handler ... +func NewHTTP2Handler(cfg *config.ServerConfig, pool *redis.Pool) (*HTTP2Handler, error) { + h := HTTP2Handler{ + nodeID: cfg.NodeID, + sessions: make(map[string]session.Session), + localhost: cfg.Localhost, + clusterURL: cfg.ClusterURL, + pool: pool, + logger: cfg.Logger.WithFields(logrus.Fields{"prefix": "HTTP2Handler"}), + } + + crt, err := tls.X509KeyPair(cfg.TLSCert, cfg.TLSPrivateKey) + if err != nil { + return nil, err + } + h.tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{crt}, + } + return &h, nil +} + +// Serve accepts incoming wormhole connections and passes them to the handler +// We are explicit with the *net.TCPConn since we need to be this way - and let the handler and +// sessions handle wrapping in TLS. Having a TCPConn all the way down will highlight the dangers +// of sending data over the socket without first wrapping in TLS +func (h *HTTP2Handler) Serve(conn net.Conn) { + // TODO: Have remote only hand off *net.TCPConn + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + h.logger.Errorf("was not given TCP socket") + return + } + tlsConn, err := h.genericTLSWrap(tcpConn) + if err != nil { + h.logger.Errorf("error establishing tls session: " + err.Error()) + return + } + + buf := make([]byte, 1024) + + nr, err := tlsConn.Read(buf) + if err != nil { + h.logger.Errorf("error reading from stream: " + err.Error()) + return + } + msg, err := messages.Unpack(buf[:nr]) + if err != nil { + h.logger.Errorf("error parsing message from stream: " + err.Error()) + return + } + + switch m := msg.(type) { + case *messages.AuthControl: + go h.http2SessionHandler(tlsConn) + case *messages.AuthTunnel: + if sess, ok := h.sessions[m.ClientID]; !ok { + h.logger.Error("New tunnel conn not associated with any session. Closing") + tlsConn.Close() + } else { + // open a proxy conn on current session + h.logger.Debugf("Adding New tunnel conn to session: %s", sess.ID()) + http2Sess := sess.(*session.HTTP2Session) + + alpnConn, err := h.http2ALPNTLSWrap(tcpConn) + if err != nil { + h.logger.Errorf("Couldn't establish ALPN connection") + return + } + + if err := http2Sess.AddTunnel(alpnConn); err != nil { + h.logger.Errorf("Error establishing Tunnel: %v+", err) + } + h.logger.Debugf("Successfully Added New tunnel conn to session: %s", sess.ID()) + } + default: + h.logger.Error("unparsable response") + tlsConn.Close() + } +} + +func (h *HTTP2Handler) genericTLSWrap(conn *net.TCPConn) (*tls.Conn, error) { + return wnet.GenericTLSWrap(conn, h.tlsConfig, tls.Server) +} + +// NOTE: The ALPN is a requirement of the spec for HTTP/2 capability discovery +// While technically the golang implementation will allow us not to perform ALPN, +// this breaks the http/2 spec. The goal here is to follow the RFC to the letter +// as documented in http://httpwg.org/specs/rfc7540.html#starting +func (h *HTTP2Handler) http2ALPNTLSWrap(conn *net.TCPConn) (*tls.Conn, error) { + return wnet.HTTP2ALPNTLSWrap(conn, h.tlsConfig, tls.Server) +} + +// Close closes all sessions handled by HTTP2Handler +func (h *HTTP2Handler) Close() { + for _, sess := range h.sessions { + sess.Close() + delete(h.sessions, sess.ID()) + } +} + +func (h *HTTP2Handler) http2SessionHandler(conn net.Conn) { + args := &session.HTTP2SessionArgs{ + Logger: h.logger.Logger, + NodeID: h.nodeID, + RedisPool: h.pool, + Conn: conn, + TLSConfig: h.tlsConfig, + } + + sess, err := session.NewHTTP2Session(args) + if err != nil { + h.logger.WithField("client_addr", conn.RemoteAddr().String()).Errorln("error creating a session:", err) + return + } + h.sessions[sess.ID()] = sess + + if err := sess.RequireStream(); err != nil { + h.logger.WithField("client_addr", conn.RemoteAddr().String()).Errorln("error getting a stream:", err) + return + } + + if err := sess.RequireAuthentication(); err != nil { + h.logger.Errorln(err) + return + } + + defer h.closeSession(sess) + + ln, err := listenTCP("tcp_ingress", sess) + if err != nil { + h.logger.Errorln(err) + return + } + + _, port, _ := net.SplitHostPort(ln.Addr().String()) + sess.EndpointAddr = h.localhost + ":" + port + sess.ClusterURL = h.clusterURL + + if err := sess.RegisterEndpoint(); err != nil { + h.logger.Errorln("Error registering endpoint:", err) + return + } + + h.logger.Infof("Started session %s for %s (%s). Listening on: %s", sess.ID(), sess.NodeID(), sess.Client(), sess.Endpoint()) + + sess.HandleRequests(ln) +} + +func (h *HTTP2Handler) closeSession(sess session.Session) { + sess.Close() + delete(h.sessions, sess.ID()) +} diff --git a/remote/http2_handler_test.go b/remote/http2_handler_test.go new file mode 100644 index 00000000..dde7e324 --- /dev/null +++ b/remote/http2_handler_test.go @@ -0,0 +1,344 @@ +package remote + +import ( + "crypto/tls" + "crypto/x509" + + "fmt" + "net" + _ "net/http" + _ "net/http/httptest" + "net/url" + "os" + "testing" + "time" + + "github.com/garyburd/redigo/redis" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/superfly/tlstest" + "github.com/superfly/wormhole/config" + "github.com/superfly/wormhole/messages" + "github.com/superfly/wormhole/session" + "golang.org/x/net/http2" + "gopkg.in/ory-am/dockertest.v3" +) + +var redisPool *redis.Pool +var serverTLSConfig *tls.Config +var clientTLSConfig *tls.Config + +var serverTLSCert tls.Certificate +var serverCrtPEM []byte +var serverKeyPEM []byte + +func TestMain(m *testing.M) { + var rootCrtPEM []byte + var err error + rootCrtPEM, serverCrtPEM, serverKeyPEM, err = tlstest.CreateServerCertKeyPEMPairWithRootCert() + if err != nil { + log.Fatalf("tlstest could not generate x509 certs %v+", err) + } + + serverTLSCert, err = tls.X509KeyPair(serverCrtPEM, serverKeyPEM) + if err != nil { + log.Fatalf("Couldn't create tls cert from keypair %v+", err) + } + + serverTLSConfig = &tls.Config{ + Certificates: []tls.Certificate{serverTLSCert}, + } + + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(rootCrtPEM) + + clientTLSConfig = &tls.Config{ + RootCAs: certPool, + ServerName: "127.0.0.1", + } + + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Dockertest could not connect to docker: %s", err) + } + + redisResource, err := pool.Run("redis", "4.0.1", []string{}) + if err != nil { + log.Fatalf("Could not create redis container") + } + + if err := pool.Retry(func() error { + var err error + c, err := redis.DialURL(fmt.Sprintf("redis://localhost:%s", redisResource.GetPort("6379/tcp"))) + if err != nil { + return err + } + _, err = c.Do("PING") + return err + }); err != nil { + log.Fatalf("Could not connect to redis container: %s", err) + } + + redisPool = newRedisPool(fmt.Sprintf("redis://localhost:%s", redisResource.GetPort("6379/tcp"))) + + code := m.Run() + + if err := pool.Purge(redisResource); err != nil { + log.Fatalf("Could not purge redis: %s", err) + } + + os.Exit(code) +} + +func newRedisPool(redisURL string) *redis.Pool { + return &redis.Pool{ + MaxIdle: 3, + IdleTimeout: 240 * time.Second, + Dial: func() (redis.Conn, error) { + conn, err := redis.DialURL(redisURL) + if err != nil { + return nil, err + } + + parsedURL, err := url.Parse(redisURL) + if err != nil { + return nil, err + } + if parsedURL.User != nil { + if password, hasPassword := parsedURL.User.Password(); hasPassword == true { + if _, authErr := conn.Do("AUTH", password); authErr != nil { + conn.Close() + return nil, authErr + } + } + } + return conn, nil + }, + TestOnBorrow: func(conn redis.Conn, t time.Time) error { + if time.Since(t) < time.Minute { + return nil + } + _, err := conn.Do("PING") + return err + }, + } +} + +func TestNewHTTP2Handler(t *testing.T) { + cfg := &config.ServerConfig{ + Config: config.Config{ + Logger: log.New(), + TLSCert: serverCrtPEM, + Localhost: "localhost", + }, + TLSPrivateKey: serverKeyPEM, + NodeID: "1", + ClusterURL: "localhost", + } + + h, err := NewHTTP2Handler(cfg, redisPool) + assert.NoError(t, err, "Should be no error creating http2 handler") + + hControl := &HTTP2Handler{ + tlsConfig: serverTLSConfig, + logger: cfg.Logger.WithFields(log.Fields{"prefix": "HTTP2Handler"}), + pool: redisPool, + localhost: "localhost", + clusterURL: "localhost", + sessions: make(map[string]session.Session), + nodeID: "1", + } + + assert.EqualValues(t, hControl, h, "Control and test HTTP2Handlers should match values") +} + +func newTestHTTP2Handler() (*HTTP2Handler, error) { + h := &HTTP2Handler{ + tlsConfig: serverTLSConfig, + logger: log.New().WithFields(log.Fields{"prefix": "TestHTTP2Handler"}), + pool: redisPool, + localhost: "localhost", + clusterURL: "localhost", + sessions: make(map[string]session.Session), + nodeID: "1", + } + + return h, nil +} + +func newServerClientTCPConns() (serverConn *net.TCPConn, clientConn *net.TCPConn, err error) { + sConnCh := make(chan *net.TCPConn) + lnAddr := &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 8085, + } + + ln, err := net.ListenTCP("tcp", lnAddr) + if err != nil { + log.Errorf("Error creating TCP listener: %+v", err) + return + } + + go func(ln *net.TCPListener) { + s, err := ln.AcceptTCP() + if err != nil { + log.Errorf("Error accepting TCP listener: %+v", err) + close(sConnCh) + } + if err := ln.Close(); err != nil { + log.Errorf("Error closing listener: %+v", err) + close(sConnCh) + } + sConnCh <- s + }(ln) + + clientConn, err = net.DialTCP("tcp", nil, lnAddr) + if err != nil { + log.Errorf("Error dialing TCP conn") + return + } + + serverConn, ok := <-sConnCh + if !ok { + return nil, nil, errors.New("Error creating server conn") + } + + return +} + +func TestWrapsTLSOnServe(t *testing.T) { + h, err := newTestHTTP2Handler() + assert.NoError(t, err, "Should be no error creating new HTTP2Handler") + assert.NotNil(t, h, "Handler shouldn't be nil") + + sConn, cConn, err := newServerClientTCPConns() + assert.NoError(t, err, "Should be no error creating server/client TCP conns") + assert.NotNil(t, sConn, "Server conn should not be nil") + assert.NotNil(t, cConn, "Client conn should not be nil") + + go h.Serve(sConn) + + tlsClientConn := tls.Client(cConn, clientTLSConfig) + + err = tlsClientConn.Handshake() + assert.NoError(t, err, "Should be no error completing tls handshake") +} + +func wrapClientConn(cConn *net.TCPConn, tlsConf *tls.Config, alpn bool) (*tls.Conn, error) { + + if alpn { + tlsConf = tlsConf.Clone() + // TODO: append here + tlsConf.NextProtos = []string{http2.NextProtoTLS} + + } + + var tlsClientConn *tls.Conn + + for { + if err := cConn.SetDeadline(time.Now().Add(time.Second * 1)); err != nil { + return nil, err + } + tlsClientConn = tls.Client(cConn, tlsConf) + + if err := tlsClientConn.Handshake(); err != nil { + if netErr, ok := err.(net.Error); ok { + if netErr.Timeout() { + continue + } + if netErr.Temporary() { + continue + } + } + return nil, err + } + + if err := cConn.SetDeadline(time.Time{}); err != nil { + return nil, err + } + break + } + if alpn { + if err := tlsClientConn.VerifyHostname(tlsConf.ServerName); err != nil { + return &tls.Conn{}, err + } + state := tlsClientConn.ConnectionState() + if p := state.NegotiatedProtocol; p != http2.NextProtoTLS { + return &tls.Conn{}, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2.NextProtoTLS) + } + + if !state.NegotiatedProtocolIsMutual { + return &tls.Conn{}, fmt.Errorf("http2: could not negotiate protocol mutually") + } + } + + return tlsClientConn, nil +} + +func TestCreatesFullSession(t *testing.T) { + h, err := newTestHTTP2Handler() + assert.NoError(t, err, "Should be no error creating new HTTP2Handler") + assert.NotNil(t, h, "Handler shouldn't be nil") + + sControlConn, cControlConn, err := newServerClientTCPConns() + assert.NoError(t, err, "Should be no error creating server/client TCP conns") + assert.NotNil(t, sControlConn, "Server conn should not be nil") + assert.NotNil(t, cControlConn, "Client conn should not be nil") + + go h.Serve(sControlConn) + + // Establish TLS connection + tlsCControlConn, err := wrapClientConn(cControlConn, clientTLSConfig, false) + assert.NoError(t, err, "Should have no error wrapping client in tls") + + authMessage := &messages.AuthControl{ + Token: "test", + } + + authData, err := messages.Pack(authMessage) + assert.NoError(t, err, "Should be no error packing message") + + _, err = tlsCControlConn.Write(authData) + assert.NoError(t, err, "Should be no error writing data") + + tunData := make([]byte, 1024) + nr, err := tlsCControlConn.Read(tunData) + assert.NoError(t, err, "Should be no error reading tunnel message") + + msg, err := messages.Unpack(tunData[:nr]) + assert.NoError(t, err, "Should be no error unpacking message") + + tunMessage, ok := msg.(*messages.OpenTunnel) + assert.True(t, ok, "Should be an opentunnel message") + + // establish new connections for tunnel socket + sTunnelConn, cTunnelConn, err := newServerClientTCPConns() + assert.NoError(t, err, "no error for conns") + + go h.Serve(sTunnelConn) + + tlsCTunnelConn, err := wrapClientConn(cTunnelConn, clientTLSConfig, false) + assert.NoError(t, err, "Should be no error wrapping client") + + authTunMessage := &messages.AuthTunnel{ + ClientID: tunMessage.ClientID, + } + + authTunData, err := messages.Pack(authTunMessage) + assert.NoError(t, err, "Should be no error packing authTunMessage") + + _, err = tlsCTunnelConn.Write(authTunData) + assert.NoError(t, err, "Should have no error writing to tunnel conn") + + cTunnelConn.SetDeadline(time.Now().Add(time.Second * 15)) + + alpnTunnelConn, err := wrapClientConn(cTunnelConn, clientTLSConfig, true) + + assert.NoError(t, err, "Should have no error establishing http2 conn tls") + + _ = alpnTunnelConn + + // TODO: Test throughput + // This is dependent on registering backend IDs with token upon creation like the SSH handler currently does +} diff --git a/remote/server.go b/remote/server.go index 2f58d61f..44e4a890 100644 --- a/remote/server.go +++ b/remote/server.go @@ -1,57 +1,46 @@ package remote import ( - "crypto/tls" "fmt" "net" - "github.com/Sirupsen/logrus" + "github.com/sirupsen/logrus" ) // Server contains configuration options for a TCP Server type Server struct { - TLSCert *[]byte - TLSPrivateKey *[]byte - Logger *logrus.Logger + Logger *logrus.Logger } // ListenAndServe accepts incoming wormhole connections and passes them to the handler func (s *Server) ListenAndServe(addr string, handler Handler) error { log := s.Logger.WithFields(logrus.Fields{"prefix": "Server"}) - listener, err := s.newListener(addr) + listener, err := s.newTCPListener(addr) if err != nil { return fmt.Errorf("Failed to listen on %s (%s)", addr, err.Error()) } for { - conn, err := listener.Accept() + tcpConn, err := listener.AcceptTCP() if err != nil { log.Errorf("Failed to accept wormhole connection (%s)", err.Error()) break } - log.Debugln("Accepted wormhole TCP conn from:", conn.RemoteAddr()) + log.Debugln("Accepted wormhole TCP conn from:", tcpConn.RemoteAddr()) - go handler.Serve(conn) + go handler.Serve(tcpConn) } return nil } -func (s *Server) newListener(addr string) (net.Listener, error) { - if s.encrypted() { - cert, err := tls.X509KeyPair(*s.TLSCert, *s.TLSPrivateKey) - if err != nil { - return nil, err - } - return tls.Listen("tcp", addr, &tls.Config{ - Certificates: []tls.Certificate{cert}, - }) +func (s *Server) newTCPListener(addr string) (*net.TCPListener, error) { + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err } - return net.Listen("tcp", addr) -} - -func (s *Server) encrypted() bool { - return s.TLSCert != nil && - s.TLSPrivateKey != nil && - len(*s.TLSCert) > 0 && - len(*s.TLSPrivateKey) > 0 + tcpLN, ok := ln.(*net.TCPListener) + if !ok { + return nil, fmt.Errorf("Could not create tcp listener") + } + return tcpLN, nil } diff --git a/remote/ssh_handler.go b/remote/ssh_handler.go index 9eaa9f42..a2d6a213 100644 --- a/remote/ssh_handler.go +++ b/remote/ssh_handler.go @@ -7,8 +7,8 @@ import ( "sync" "time" - "github.com/Sirupsen/logrus" "github.com/garyburd/redigo/redis" + "github.com/sirupsen/logrus" "github.com/superfly/wormhole/config" wnet "github.com/superfly/wormhole/net" "github.com/superfly/wormhole/session" diff --git a/remote/tcp_handler.go b/remote/tcp_handler.go index ca6204d9..1a945146 100644 --- a/remote/tcp_handler.go +++ b/remote/tcp_handler.go @@ -3,8 +3,8 @@ package remote import ( "net" - "github.com/Sirupsen/logrus" "github.com/garyburd/redigo/redis" + "github.com/sirupsen/logrus" "github.com/superfly/wormhole/config" "github.com/superfly/wormhole/messages" "github.com/superfly/wormhole/session" diff --git a/scripts/wormhole-local.sh b/scripts/wormhole-local.sh index 2bd54c12..f49e5852 100755 --- a/scripts/wormhole-local.sh +++ b/scripts/wormhole-local.sh @@ -48,7 +48,7 @@ register_client() { spawn_wormhole() { token=$1 - FLY_TOKEN=$token FLY_PORT=$PORT $GOPATH/src/github.com/superfly/wormhole/cmd/wormhole/wormhole & + FLY_TOKEN=$token FLY_PORT=$PORT $GOPATH/src/github.com/superfly/wormhole/bin/wormhole & CHILD_PIDS+=("$!") echo "DONE (PID: $!)" } diff --git a/scripts/wormhole-server.sh b/scripts/wormhole-server.sh index 7c00650e..915664a7 100755 --- a/scripts/wormhole-server.sh +++ b/scripts/wormhole-server.sh @@ -25,6 +25,6 @@ export FLY_TLS_CERT_FILE=$GOPATH/src/github.com/superfly/wormhole/scripts/cert.p export FLY_TLS_PRIVATE_KEY_FILE=$GOPATH/src/github.com/superfly/wormhole/scripts/key.pem -WORMHOLE_BIN=$GOPATH/src/github.com/superfly/wormhole/cmd/wormhole/wormhole +WORMHOLE_BIN=$GOPATH/src/github.com/superfly/wormhole/bin/wormhole $WORMHOLE_BIN -server diff --git a/session/http2_session.go b/session/http2_session.go new file mode 100644 index 00000000..fd6db8b7 --- /dev/null +++ b/session/http2_session.go @@ -0,0 +1,361 @@ +package session + +import ( + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "sync/atomic" + "time" + + "github.com/garyburd/redigo/redis" + "github.com/rs/xid" + "github.com/sirupsen/logrus" + "github.com/superfly/wormhole/messages" + wnet "github.com/superfly/wormhole/net" + + "golang.org/x/net/http2" +) + +// HTTP2Session extends information about connected client stored in Session. +// It also includes: +// - control connection for exchanging communication with the client +// - channel with available tunnel connections +// - timestamp with the last known ping from the client +type HTTP2Session struct { + baseSession + + control net.Conn + conns wnet.ConnPool + server *http.Server + transport *http2.Transport + + lastPingAt int64 +} + +// HTTP2SessionArgs defines the arguments to be passed to NewHTTP2Session +type HTTP2SessionArgs struct { + Logger *logrus.Logger + NodeID string + TLSConfig *tls.Config + RedisPool *redis.Pool + Conn net.Conn +} + +// NewHTTP2Session creates new TCPSession struct +func NewHTTP2Session(args *HTTP2SessionArgs) (*HTTP2Session, error) { + base := baseSession{ + id: xid.New().String(), + nodeID: args.NodeID, + store: NewRedisStore(args.RedisPool), + logger: args.Logger.WithFields(logrus.Fields{"prefix": "HTTP2Session"}), + } + s := &HTTP2Session{ + control: args.Conn, + baseSession: base, + transport: &http2.Transport{}, + lastPingAt: time.Now().UnixNano(), + } + + server := &http.Server{ + Handler: s, + TLSConfig: args.TLSConfig.Clone(), // Currently doesn't do anything since we listen with tcp + } + + if err := http2.ConfigureServer(server, &http2.Server{}); err != nil { + return nil, err + } + + s.server = server + + pool, err := wnet.NewConnPool( + args.Logger.WithFields(logrus.Fields{"prefix": "HTTP2ConnPool"}), + 10, + []wnet.ConnPoolObject{}) + if err != nil { + return nil, err + } + s.conns = pool + + return s, nil +} + +type http2Tunnel struct { + conn *tls.Conn + cc *http2.ClientConn + cStreams uint32 + maxCStreams uint32 +} + +func (c *http2Tunnel) Close() error { + return c.conn.Close() +} + +func (c *http2Tunnel) ShouldDelete() bool { + return !c.cc.CanTakeNewRequest() +} + +func (c *http2Tunnel) ShouldQueue() bool { + cs := atomic.LoadUint32(&c.cStreams) + + return cs <= c.maxCStreams +} + +func (c *http2Tunnel) incrementStreamCount() { + atomic.AddUint32(&c.cStreams, 1) +} + +func (c *http2Tunnel) decrementStreamCount() { + atomic.AddUint32(&c.cStreams, ^uint32(0)) +} + +func (c *http2Tunnel) rewriteRequest(r *http.Request) { + r.URL.Host = c.conn.RemoteAddr().String() + r.URL.Scheme = "https" + r.Host = c.conn.RemoteAddr().String() + r.RequestURI = "" +} + +func (c *http2Tunnel) RoundTrip(r *http.Request) (*http.Response, error) { + c.rewriteRequest(r) + + c.incrementStreamCount() + defer c.decrementStreamCount() + + return c.cc.RoundTrip(r) +} + +// AddTunnel adds a connection to the pool of tunnel connections +func (s *HTTP2Session) AddTunnel(conn *tls.Conn) error { + cc, err := s.transport.NewClientConn(conn) + if err != nil { + return err + } + + poolObj := &http2Tunnel{ + conn: conn, + cc: cc, + cStreams: 0, + maxCStreams: 10, + } + + ok, err := s.conns.Insert(poolObj) + if err != nil { + return err + } + if !ok { + s.logger.Warn("Connection pool is full while trying to add ClientConn") + } + + return nil +} + +// RequireStream sends a request to the client to open a new tunnel Connection +// for this Session. +func (s *HTTP2Session) RequireStream() error { + return s.openTunnel() +} + +// HandleRequests handles all requests coming over the control connection from the client. +// The main function is to accept ingress traffic (from the listener) once the remote port +// forwarding is set up. +// It also handles out-of-band communication, like the maintaining the Session heartbeat or +// request the client to open new tunnel connections. +func (s *HTTP2Session) HandleRequests(ln net.Listener) { + go s.controlLoop() + go s.heartbeat() + s.handleRemoteForward(ln) +} + +// RequireAuthentication registers the connection +// TODO: add authentication here +func (s *HTTP2Session) RequireAuthentication() error { + s.RegisterConnection(time.Now()) + return nil +} + +// Close closes SSHSession and registers disconnection +func (s *HTTP2Session) Close() { + s.RegisterDisconnection() + s.logger.Infof("Closed session %s for %s %s (%s).", s.ID(), s.NodeID(), s.Agent(), s.Client()) + s.server.Close() + s.control.Close() +} + +// handleRemoteForward listens for TLS connection and connects it to a session +// NOTE: http2 in golang REQUIRES tls. No h2c spec supported +// TODO: instead of manually listening for TCP conns - just listen via a http2 server +// TODO: Currently only handles TCP not TLS +func (s *HTTP2Session) handleRemoteForward(ln net.Listener) { + defer func() { + err := ln.Close() + if err != nil { + s.logger.Debugf("Couldn't close ingress conn: %s", err) + return + } + s.logger.Debugf("Closed ingress conn: %s", ln.Addr().String()) + }() + + if err := s.server.Serve(ln); err != nil { + s.logger.Errorf("Stopped being able to serve ingress traffic: %v+", err) + return + } +} + +// ServeHTTP... +func (s *HTTP2Session) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var resp *http.Response + var err error + for { + obj := s.conns.Get() + + conn, ok := obj.(*http2Tunnel) + if !ok { + s.logger.Error("Got wrong object type from connection pool") + return + } + + // HTTP2 doesn't support these header types, so delete them + // This is if the end client doesn't support HTTP2 + // Taken from go core at net/http2/transport.go + if v := r.Header.Get("Upgrade"); v != "" { + r.Header.Del("Upgrade") + } + if v := r.Header.Get("Transfer-Encoding"); (v != "" && v != "chunked") || len(r.Header["Transfer-Encoding"]) > 1 { + r.Header.Del("Transfer-Encoding") + } + if v := r.Header.Get("Connection"); (v != "" && v != "close" && v != "keep-alive") || len(r.Header["Connection"]) > 1 { + r.Header.Del("Connection") + } + + resp, err = conn.RoundTrip(r) + // TODO: Handle this error + if err != nil { + if conn.ShouldDelete() { + continue + } + s.logger.Warnf("Error with requiest: %+v", err) + if netErr, ok := err.(net.Error); !ok { + // do something not network error related + // maybe retry? + _ = netErr + } + break + } + break + } + + // Delete this so we don't copy it over + // Will be handled by http.ResponseWriter + resp.Header.Del("Content-Length") + + for key, values := range resp.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + + w.WriteHeader(resp.StatusCode) + + defer resp.Body.Close() + + nr, err := io.Copy(w, resp.Body) + if err != nil { + s.logger.Errorf("Could not copy response body") + return + } + s.logger.Infof("Copied %d bytes between connection bodies", nr) +} + +func (s *HTTP2Session) openTunnel() error { + msg := &messages.OpenTunnel{ClientID: s.id} + b, err := messages.Pack(msg) + if err != nil { + return fmt.Errorf("Couldn't create a request to open new tunnel: %s", err.Error()) + } + _, err = s.control.Write(b) + if err != nil { + return fmt.Errorf("Failed to send request to open new tunnel: %s", err.Error()) + } + return nil +} + +func (s *HTTP2Session) heartbeat() { + // timer for detecting heartbeat failure + connCheck := time.NewTicker(connCheckInterval) + defer connCheck.Stop() + + for { + select { + case <-connCheck.C: + lastPing := time.Unix(0, atomic.LoadInt64(&s.lastPingAt)) + if time.Since(lastPing) > pingTimeoutInterval { + s.Close() + return + } + } + } +} + +func (s *HTTP2Session) controlLoop() { + b := make([]byte, 1024) + + for { + nr, err := s.control.Read(b) + if err == io.EOF { + continue + } + if err != nil { + s.logger.Errorf("error reading from control: " + err.Error()) + s.Close() + return + } + msg, err := messages.Unpack(b[:nr]) + if err != nil { + s.logger.Errorf("error parsing message from stream: " + err.Error()) + s.Close() + return + } + switch m := msg.(type) { + case *messages.Shutdown: + s.logger.Debugf("Received Shutdown message: %s", m.Error) + s.Close() + return + case *messages.Ping: + s.logger.Debug("Received Ping message.") + atomic.StoreInt64(&s.lastPingAt, time.Now().UnixNano()) + bw, err := messages.Pack(&messages.Pong{}) + if err != nil { + s.logger.Errorf("Couldn't create a Pong message: %s", err.Error()) + } + _, err = s.control.Write(bw) + if err != nil { + s.logger.Errorf("Failed to send Pong message: %s", err.Error()) + } + default: + s.logger.Warn("Unrecognized command. Ignoring.") + } + } +} + +// RegisterConnection creates and stores a new session record +func (s *HTTP2Session) RegisterConnection(t time.Time) error { + return s.store.RegisterConnection(s) +} + +// RegisterDisconnection destroys the session record +func (s *HTTP2Session) RegisterDisconnection() error { + return s.store.RegisterDisconnection(s) +} + +// RegisterEndpoint registers the endpoint and adds it to the current session record +// The endpoint is a particular instance of a running wormhole client +func (s *HTTP2Session) RegisterEndpoint() error { + return s.store.RegisterEndpoint(s) +} + +// UpdateAttribute updates a particular attribute of the current session record +func (s *HTTP2Session) UpdateAttribute(name string, value interface{}) error { + return s.store.UpdateAttribute(s, name, value) +} diff --git a/session/http2_session_test.go b/session/http2_session_test.go new file mode 100644 index 00000000..98845564 --- /dev/null +++ b/session/http2_session_test.go @@ -0,0 +1,262 @@ +package session + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/url" + "os" + "testing" + "time" + + "github.com/garyburd/redigo/redis" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/superfly/tlstest" + _ "github.com/superfly/wormhole/config" + "github.com/superfly/wormhole/messages" + wnet "github.com/superfly/wormhole/net" + "golang.org/x/net/http2" + "gopkg.in/ory-am/dockertest.v3" +) + +var redisPool *redis.Pool +var serverTLSConfig *tls.Config +var clientTLSConfig *tls.Config + +var serverTLSCert tls.Certificate +var serverCrtPEM []byte +var serverKeyPEM []byte + +func TestMain(m *testing.M) { + var rootCrtPEM []byte + var err error + rootCrtPEM, serverCrtPEM, serverKeyPEM, err = tlstest.CreateServerCertKeyPEMPairWithRootCert() + if err != nil { + log.Fatalf("tlstest could not generate x509 certs %v+", err) + } + + serverTLSCert, err = tls.X509KeyPair(serverCrtPEM, serverKeyPEM) + if err != nil { + log.Fatalf("Couldn't create tls cert from keypair %v+", err) + } + + serverTLSConfig = &tls.Config{ + Certificates: []tls.Certificate{serverTLSCert}, + } + + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(rootCrtPEM) + + clientTLSConfig = &tls.Config{ + RootCAs: certPool, + ServerName: "127.0.0.1", + } + + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Dockertest could not connect to docker: %s", err) + } + + redisResource, err := pool.Run("redis", "4.0.1", []string{}) + if err != nil { + log.Fatalf("Could not create redis container") + } + + if err := pool.Retry(func() error { + var err error + c, err := redis.DialURL(fmt.Sprintf("redis://localhost:%s", redisResource.GetPort("6379/tcp"))) + if err != nil { + return err + } + _, err = c.Do("PING") + return err + }); err != nil { + log.Fatalf("Could not connect to redis container: %s", err) + } + + redisPool = newRedisPool(fmt.Sprintf("redis://localhost:%s", redisResource.GetPort("6379/tcp"))) + + code := m.Run() + + if err := pool.Purge(redisResource); err != nil { + log.Fatalf("Could not purge redis: %s", err) + } + + os.Exit(code) +} + +func newRedisPool(redisURL string) *redis.Pool { + return &redis.Pool{ + MaxIdle: 3, + IdleTimeout: 240 * time.Second, + Dial: func() (redis.Conn, error) { + conn, err := redis.DialURL(redisURL) + if err != nil { + return nil, err + } + + parsedURL, err := url.Parse(redisURL) + if err != nil { + return nil, err + } + if parsedURL.User != nil { + if password, hasPassword := parsedURL.User.Password(); hasPassword == true { + if _, authErr := conn.Do("AUTH", password); authErr != nil { + conn.Close() + return nil, authErr + } + } + } + return conn, nil + }, + TestOnBorrow: func(conn redis.Conn, t time.Time) error { + if time.Since(t) < time.Minute { + return nil + } + _, err := conn.Do("PING") + return err + }, + } +} + +func newServerClientTLSConns(alpn bool) (serverTLSConn *tls.Conn, clientTLSConn *tls.Conn, err error) { + sConnCh := make(chan *net.TCPConn) + lnAddr := &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 8085, + } + + ln, err := net.ListenTCP("tcp", lnAddr) + if err != nil { + log.Errorf("Error creating TCP listener: %+v", err) + return + } + + go func(ln *net.TCPListener) { + s, err := ln.AcceptTCP() + if err != nil { + log.Errorf("Error accepting TCP listener: %+v", err) + close(sConnCh) + } + if err := ln.Close(); err != nil { + log.Errorf("Error closing listener: %+v", err) + close(sConnCh) + } + sConnCh <- s + }(ln) + + cConn, err := net.DialTCP("tcp", nil, lnAddr) + if err != nil { + log.Errorf("Error dialing TCP conn") + return + } + + sConn, ok := <-sConnCh + if !ok { + return nil, nil, errors.New("Error creating server conn") + } + + sTLSConnCh := make(chan *tls.Conn) + + var wrapFunc func(*net.TCPConn, *tls.Config, wnet.TLSWrapperFunc) (*tls.Conn, error) + if alpn { + wrapFunc = wnet.HTTP2ALPNTLSWrap + } else { + wrapFunc = wnet.GenericTLSWrap + } + + go func(sConn *net.TCPConn) { + sTLSConn, err := wrapFunc(sConn, serverTLSConfig, tls.Server) + if err != nil { + log.Errorf("Error creating tls wrap server") + close(sConnCh) + } + sTLSConnCh <- sTLSConn + }(sConn) + + clientTLSConn, err = wrapFunc(cConn, clientTLSConfig, tls.Client) + if err != nil { + return nil, nil, err + } + + serverTLSConn, ok = <-sTLSConnCh + if !ok { + return nil, nil, errors.New("Error creating server tls wrap") + } + + return +} + +func TestHTTP2Session(t *testing.T) { + sConn, cConn, err := newServerClientTLSConns(false) + assert.NoError(t, err, "Should be no error creating conns") + + args := &HTTP2SessionArgs{ + Logger: log.New(), + NodeID: "test_id", + TLSConfig: serverTLSConfig, + RedisPool: redisPool, + Conn: sConn, + } + + s, err := NewHTTP2Session(args) + assert.NoError(t, err, "Should be no error creating http2 session") + + t.Run("Test_open_tunnel", func(t *testing.T) { + err = s.openTunnel() + assert.NoError(t, err, "Should be no error opening tunnel") + + b := make([]byte, 1024) + nr, err := cConn.Read(b) + assert.NoError(t, err, "Should be no error reading data") + + msg, err := messages.Unpack(b[:nr]) + assert.NoError(t, err, "Should be no error unpacking") + + oTunMsg, ok := msg.(*messages.OpenTunnel) + assert.True(t, ok, "Should be an opentunnel message") + + assert.Equal(t, oTunMsg.ClientID, s.id) + }) + + t.Run("Test_round_trip", func(t *testing.T) { + sHTTPConn, cHTTPConn, err := newServerClientTLSConns(true) + assert.NoError(t, err, "Should be no error getting new conns") + + http2Server := &http2.Server{} + http2ServerConnOpts := &http2.ServeConnOpts{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "test") + }), + } + + go func() { + http2Server.ServeConn(cHTTPConn, http2ServerConnOpts) + assert.False(t, true, "Should never stop serving conn during this test") + }() + + err = s.AddTunnel(sHTTPConn) + assert.NoError(t, err, "Should be no error adding tunnel") + + ln, err := net.Listen("tcp4", ":0") + assert.NoError(t, err, "Should have no error listening") + + go func() { + s.handleRemoteForward(ln) + assert.False(t, true, "Should never stop handling forward during this test") + }() + + resp, err := http.Get(fmt.Sprintf("http://%s", ln.Addr().String())) + assert.NoError(t, err, "Should not have error requesting") + + body, err := ioutil.ReadAll(resp.Body) + assert.NoError(t, err, "Should have no error parsing body") + assert.Equal(t, "test", string(body), "Should have matching request body") + + }) +} diff --git a/session/session.go b/session/session.go index 0231d906..d687f6c3 100644 --- a/session/session.go +++ b/session/session.go @@ -3,7 +3,7 @@ package session import ( "net" - "github.com/Sirupsen/logrus" + "github.com/sirupsen/logrus" "github.com/superfly/wormhole/messages" ) diff --git a/session/ssh_session.go b/session/ssh_session.go index 205c17e4..5c113681 100644 --- a/session/ssh_session.go +++ b/session/ssh_session.go @@ -11,10 +11,10 @@ import ( msgpack "gopkg.in/vmihailenco/msgpack.v2" - "github.com/Sirupsen/logrus" "github.com/garyburd/redigo/redis" "github.com/prometheus/client_golang/prometheus" "github.com/rs/xid" + "github.com/sirupsen/logrus" "github.com/superfly/wormhole/messages" "golang.org/x/crypto/ssh" diff --git a/session/tcp_session.go b/session/tcp_session.go index d8467138..9356aa4f 100644 --- a/session/tcp_session.go +++ b/session/tcp_session.go @@ -7,9 +7,9 @@ import ( "sync/atomic" "time" - "github.com/Sirupsen/logrus" "github.com/garyburd/redigo/redis" "github.com/rs/xid" + "github.com/sirupsen/logrus" "github.com/superfly/wormhole/messages" wnet "github.com/superfly/wormhole/net" )