diff --git a/cli/serverv2/command.go b/cli/serverv2/command.go index 3756ed7..97e789c 100644 --- a/cli/serverv2/command.go +++ b/cli/serverv2/command.go @@ -10,8 +10,7 @@ import ( "time" "github.com/andydunstall/piko/pkg/log" - proxyserver "github.com/andydunstall/piko/serverv2/server/proxy" - upstreamserver "github.com/andydunstall/piko/serverv2/server/upstream" + "github.com/andydunstall/piko/serverv2/reverseproxy" "github.com/andydunstall/piko/serverv2/upstream" rungroup "github.com/oklog/run" "github.com/spf13/cobra" @@ -61,8 +60,8 @@ func run(logger log.Logger) error { } upstreamManager := upstream.NewManager() - proxyServer := proxyserver.NewServer(upstreamManager, logger) - upstreamServer := upstreamserver.NewServer(upstreamManager, nil, logger) + proxyServer := reverseproxy.NewServer(upstreamManager, logger) + upstreamServer := upstream.NewServer(upstreamManager, nil, logger) var group rungroup.Group diff --git a/client/piko.go b/client/piko.go index cb305a1..242a467 100644 --- a/client/piko.go +++ b/client/piko.go @@ -2,9 +2,11 @@ package piko import ( "context" + "encoding/binary" "encoding/json" "errors" "fmt" + "io" "net/url" "sync" "time" @@ -169,8 +171,14 @@ func (p *Piko) receive() { return } + var sz int64 + if err := binary.Read(stream, binary.BigEndian, &sz); err != nil { + p.logger.Warn("failed to read proxy header", zap.Error(err)) + continue + } + var header protocol.ProxyHeader - if err := json.NewDecoder(stream).Decode(&header); err != nil { + if err := json.NewDecoder(io.LimitReader(stream, sz)).Decode(&header); err != nil { p.logger.Warn("failed to read proxy header", zap.Error(err)) continue } diff --git a/serverv2/reverseproxy/logger.go b/serverv2/reverseproxy/logger.go new file mode 100644 index 0000000..c481cd1 --- /dev/null +++ b/serverv2/reverseproxy/logger.go @@ -0,0 +1,49 @@ +package reverseproxy + +import ( + "net/http" + "time" + + "github.com/andydunstall/piko/pkg/log" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +type loggedRequest struct { + Proto string `json:"proto"` + Method string `json:"method"` + Host string `json:"host"` + Path string `json:"path"` + RequestHeaders http.Header `json:"request_headers"` + ResponseHeaders http.Header `json:"response_headers"` + Status int `json:"status"` + Duration string `json:"duration"` +} + +// NewLoggerMiddleware creates logging middleware that logs every request. +func NewLoggerMiddleware(accessLog bool, logger log.Logger) gin.HandlerFunc { + logger = logger.WithSubsystem("reverseproxy.access") + return func(c *gin.Context) { + s := time.Now() + + c.Next() + + req := &loggedRequest{ + Proto: c.Request.Proto, + Method: c.Request.Method, + Host: c.Request.Method, + Path: c.Request.URL.Path, + RequestHeaders: c.Request.Header, + ResponseHeaders: c.Writer.Header(), + Status: c.Writer.Status(), + Duration: time.Since(s).String(), + } + if c.Writer.Status() > http.StatusInternalServerError { + logger.Warn("request", zap.Any("request", req)) + } else if accessLog { + logger.Info("request", zap.Any("request", req)) + } else { + logger.Warn("request", zap.Any("request", req)) + } + } +} diff --git a/serverv2/reverseproxy/reverseproxy.go b/serverv2/reverseproxy/reverseproxy.go new file mode 100644 index 0000000..088841e --- /dev/null +++ b/serverv2/reverseproxy/reverseproxy.go @@ -0,0 +1,100 @@ +package reverseproxy + +import ( + "context" + "net" + "net/http" + "net/http/httputil" + "net/url" + "strings" + + "github.com/andydunstall/piko/pkg/log" +) + +// Handler implements a reverse proxy HTTP handler that accepts requests from +// downstream clients and forwards them to upstream services. +type Handler struct { + upstreams UpstreamPool + + proxy *httputil.ReverseProxy + + logger log.Logger +} + +func NewHandler(upstreams UpstreamPool, logger log.Logger) *Handler { + logger = logger.WithSubsystem("reverseproxy") + handler := &Handler{ + upstreams: upstreams, + logger: logger, + } + + transport := &http.Transport{ + DialContext: handler.dialContext, + // 'connections' to the upstream are multiplexed over a single TCP + // connection so theres no overhead to creating new connections, + // therefore it doesn't make sense to keep them alive. + DisableKeepAlives: true, + } + proxy := &httputil.ReverseProxy{ + Transport: transport, + Rewrite: func(r *httputil.ProxyRequest) { + r.SetURL(&url.URL{ + Scheme: "http", + Host: r.In.Host, + Path: r.In.URL.Path, + RawQuery: r.In.URL.RawQuery, + }) + }, + } + handler.proxy = proxy + return handler +} + +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + endpointID := endpointIDFromRequest(r) + if endpointID == "" { + h.logger.Warn("request: missing endpoint id") + http.Error(w, `{"message": "missing endpoint id"}`, http.StatusBadGateway) + return + } + + // nolint + ctx := context.WithValue(r.Context(), "_piko_endpoint", endpointID) + r = r.WithContext(ctx) + + h.proxy.ServeHTTP(w, r) +} + +// dialContext dials the endpoint ID in ctx. This is a bit of a hack to work +// with http.Transport. +func (h *Handler) dialContext(ctx context.Context, _, _ string) (net.Conn, error) { + // TODO(andydunstall): Alternatively wrap Transport.RoundTrip and first + // parse the endpoint ID, then decide whether to forward to a local + // connection or a remote node. + endpointID := ctx.Value("_piko_endpoint").(string) + return h.upstreams.Dial(endpointID) +} + +// endpointIDFromRequest returns the endpoint ID from the HTTP request, or an +// empty string if no endpoint ID is specified. +// +// This will check both the 'x-piko-endpoint' header and 'Host' header, where +// x-piko-endpoint takes precedence. +func endpointIDFromRequest(r *http.Request) string { + endpointID := r.Header.Get("x-piko-endpoint") + if endpointID != "" { + return endpointID + } + + host := r.Host + if host != "" && strings.Contains(host, ".") { + // If a host is given and contains a separator, use the bottom-level + // domain as the endpoint ID. + // + // Such as if the domain is 'xyz.piko.example.com', then 'xyz' is the + // endpoint ID. + return strings.Split(host, ".")[0] + } + + return "" +} diff --git a/serverv2/reverseproxy/server.go b/serverv2/reverseproxy/server.go new file mode 100644 index 0000000..1609f80 --- /dev/null +++ b/serverv2/reverseproxy/server.go @@ -0,0 +1,88 @@ +package reverseproxy + +import ( + "context" + "fmt" + "net" + "net/http" + + "github.com/andydunstall/piko/pkg/log" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +type Server struct { + handler *Handler + + router *gin.Engine + + httpServer *http.Server + + logger log.Logger +} + +func NewServer( + upstreams UpstreamPool, + logger log.Logger, +) *Server { + logger = logger.WithSubsystem("reverseproxy") + + router := gin.New() + server := &Server{ + handler: NewHandler(upstreams, logger), + router: router, + httpServer: &http.Server{ + Handler: router, + }, + logger: logger, + } + + // Recover from panics. + server.router.Use(gin.CustomRecoveryWithWriter(nil, server.panicRoute)) + + server.router.Use(NewLoggerMiddleware(true, logger)) + + server.registerRoutes() + + return server +} + +func (s *Server) Serve(ln net.Listener) error { + s.logger.Info("starting http server", zap.String("addr", ln.Addr().String())) + if err := s.httpServer.Serve(ln); err != nil && err != http.ErrServerClosed { + return fmt.Errorf("http serve: %w", err) + } + return nil +} + +func (s *Server) Shutdown(ctx context.Context) error { + return s.httpServer.Shutdown(ctx) +} + +func (s *Server) registerRoutes() { + // Handle not found routes, which includes all proxied endpoints. + s.router.NoRoute(s.notFoundRoute) +} + +// proxyRoute handles proxied requests from proxy clients. +func (s *Server) proxyRoute(c *gin.Context) { + s.handler.ServeHTTP(c.Writer, c.Request) +} + +func (s *Server) notFoundRoute(c *gin.Context) { + s.proxyRoute(c) +} + +func (s *Server) panicRoute(c *gin.Context, err any) { + s.logger.Error( + "handler panic", + zap.String("path", c.FullPath()), + zap.Any("err", err), + ) + c.AbortWithStatus(http.StatusInternalServerError) +} + +func init() { + // Disable Gin debug logs. + gin.SetMode(gin.ReleaseMode) +} diff --git a/serverv2/reverseproxy/upstream.go b/serverv2/reverseproxy/upstream.go new file mode 100644 index 0000000..bc91aeb --- /dev/null +++ b/serverv2/reverseproxy/upstream.go @@ -0,0 +1,10 @@ +package reverseproxy + +import ( + "net" +) + +// UpstreamPool contains the connected upstreams. +type UpstreamPool interface { + Dial(endpointID string) (net.Conn, error) +} diff --git a/serverv2/server/proxy/server.go b/serverv2/server/proxy/server.go deleted file mode 100644 index 233af56..0000000 --- a/serverv2/server/proxy/server.go +++ /dev/null @@ -1,67 +0,0 @@ -package proxy - -import ( - "context" - "net" - "net/http" - "net/http/httputil" - "net/url" - "time" - - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/serverv2/upstream" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -type Server struct { - server *http.Server - - logger log.Logger -} - -func NewServer(manager *upstream.Manager, logger log.Logger) *Server { - transport := &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - endpointID, _, _ := net.SplitHostPort(addr) - return manager.Dial(endpointID) - }, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } - - // TODO(andydunstall): Configure timeouts, access log, ... - proxy := &httputil.ReverseProxy{ - Transport: transport, - Rewrite: func(r *httputil.ProxyRequest) { - // TODO(andydunstall): For now hacky approach assuming - // x-piko-endpoint is set. - u, _ := url.Parse("http://" + r.In.Header.Get("x-piko-endpoint")) - r.SetURL(u) - }, - ErrorLog: logger.StdLogger(zapcore.WarnLevel), - } - return &Server{ - server: &http.Server{ - Handler: proxy, - ErrorLog: logger.StdLogger(zapcore.WarnLevel), - }, - logger: logger, - } -} - -// Serve serves connections on the listener. -func (s *Server) Serve(ln net.Listener) error { - s.logger.Info( - "starting http server", - zap.String("addr", ln.Addr().String()), - ) - return s.server.Serve(ln) -} - -func (s *Server) Shutdown(ctx context.Context) error { - return s.server.Shutdown(ctx) -} diff --git a/serverv2/upstream/manager.go b/serverv2/upstream/manager.go index 94223ad..2543073 100644 --- a/serverv2/upstream/manager.go +++ b/serverv2/upstream/manager.go @@ -1,6 +1,7 @@ package upstream import ( + "encoding/binary" "encoding/json" "fmt" "net" @@ -51,7 +52,14 @@ func (m *Manager) Dial(endpointID string) (net.Conn, error) { header := protocol.ProxyHeader{ EndpointID: endpointID, } - if err := json.NewEncoder(stream).Encode(header); err != nil { + b, err := json.Marshal(header) + if err != nil { + return nil, fmt.Errorf("encode proxy header: %w", err) + } + if err := binary.Write(stream, binary.BigEndian, int64(len(b))); err != nil { + return nil, fmt.Errorf("write proxy header: %w", err) + } + if _, err := stream.Write(b); err != nil { return nil, fmt.Errorf("write proxy header: %w", err) } diff --git a/serverv2/server/upstream/server.go b/serverv2/upstream/server.go similarity index 96% rename from serverv2/server/upstream/server.go rename to serverv2/upstream/server.go index 24c1c23..2322b5a 100644 --- a/serverv2/server/upstream/server.go +++ b/serverv2/upstream/server.go @@ -1,4 +1,4 @@ -package server +package upstream import ( "context" @@ -12,7 +12,6 @@ import ( "github.com/andydunstall/piko/pkg/log" "github.com/andydunstall/piko/pkg/protocol" pikowebsocket "github.com/andydunstall/piko/pkg/websocket" - "github.com/andydunstall/piko/serverv2/upstream" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "go.uber.org/zap" @@ -22,7 +21,7 @@ import ( // Server accepts connections from upstream services. type Server struct { - manager *upstream.Manager + manager *Manager router *gin.Engine @@ -34,7 +33,7 @@ type Server struct { } func NewServer( - manager *upstream.Manager, + manager *Manager, tlsConfig *tls.Config, logger log.Logger, ) *Server { @@ -166,7 +165,7 @@ func (s *Server) handleListenRequest(sess *muxado.Heartbeat, stream muxado.Typed // TODO(andydunstall): Handle unregistering. - upstream := &upstream.Upstream{ + upstream := &Upstream{ EndpointID: req.EndpointID, Sess: sess, }