Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

serverv2: add reverse proxy #61

Merged
merged 1 commit into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions cli/serverv2/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ import (
"time"

"github.com/andydunstall/piko/pkg/log"
proxyserver "github.com/andydunstall/piko/serverv2/server/proxy"
upstreamserver "github.com/andydunstall/piko/serverv2/server/upstream"
"github.com/andydunstall/piko/serverv2/reverseproxy"
"github.com/andydunstall/piko/serverv2/upstream"
rungroup "github.com/oklog/run"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -61,8 +60,8 @@ func run(logger log.Logger) error {
}

upstreamManager := upstream.NewManager()
proxyServer := proxyserver.NewServer(upstreamManager, logger)
upstreamServer := upstreamserver.NewServer(upstreamManager, nil, logger)
proxyServer := reverseproxy.NewServer(upstreamManager, logger)
upstreamServer := upstream.NewServer(upstreamManager, nil, logger)

var group rungroup.Group

Expand Down
10 changes: 9 additions & 1 deletion client/piko.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package piko

import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net/url"
"sync"
"time"
Expand Down Expand Up @@ -169,8 +171,14 @@ func (p *Piko) receive() {
return
}

var sz int64
if err := binary.Read(stream, binary.BigEndian, &sz); err != nil {
p.logger.Warn("failed to read proxy header", zap.Error(err))
continue
}

var header protocol.ProxyHeader
if err := json.NewDecoder(stream).Decode(&header); err != nil {
if err := json.NewDecoder(io.LimitReader(stream, sz)).Decode(&header); err != nil {
p.logger.Warn("failed to read proxy header", zap.Error(err))
continue
}
Expand Down
49 changes: 49 additions & 0 deletions serverv2/reverseproxy/logger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package reverseproxy

import (
"net/http"
"time"

"github.com/andydunstall/piko/pkg/log"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)

type loggedRequest struct {
Proto string `json:"proto"`
Method string `json:"method"`
Host string `json:"host"`
Path string `json:"path"`
RequestHeaders http.Header `json:"request_headers"`
ResponseHeaders http.Header `json:"response_headers"`
Status int `json:"status"`
Duration string `json:"duration"`
}

// NewLoggerMiddleware creates logging middleware that logs every request.
func NewLoggerMiddleware(accessLog bool, logger log.Logger) gin.HandlerFunc {
logger = logger.WithSubsystem("reverseproxy.access")
return func(c *gin.Context) {
s := time.Now()

c.Next()

req := &loggedRequest{
Proto: c.Request.Proto,
Method: c.Request.Method,
Host: c.Request.Method,
Path: c.Request.URL.Path,
RequestHeaders: c.Request.Header,
ResponseHeaders: c.Writer.Header(),
Status: c.Writer.Status(),
Duration: time.Since(s).String(),
}
if c.Writer.Status() > http.StatusInternalServerError {
logger.Warn("request", zap.Any("request", req))
} else if accessLog {
logger.Info("request", zap.Any("request", req))
} else {
logger.Warn("request", zap.Any("request", req))
}
}
}
100 changes: 100 additions & 0 deletions serverv2/reverseproxy/reverseproxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package reverseproxy

import (
"context"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"

"github.com/andydunstall/piko/pkg/log"
)

// Handler implements a reverse proxy HTTP handler that accepts requests from
// downstream clients and forwards them to upstream services.
type Handler struct {
upstreams UpstreamPool

proxy *httputil.ReverseProxy

logger log.Logger
}

func NewHandler(upstreams UpstreamPool, logger log.Logger) *Handler {
logger = logger.WithSubsystem("reverseproxy")
handler := &Handler{
upstreams: upstreams,
logger: logger,
}

transport := &http.Transport{
DialContext: handler.dialContext,
// 'connections' to the upstream are multiplexed over a single TCP
// connection so theres no overhead to creating new connections,
// therefore it doesn't make sense to keep them alive.
DisableKeepAlives: true,
}
proxy := &httputil.ReverseProxy{
Transport: transport,
Rewrite: func(r *httputil.ProxyRequest) {
r.SetURL(&url.URL{
Scheme: "http",
Host: r.In.Host,
Path: r.In.URL.Path,
RawQuery: r.In.URL.RawQuery,
})
},
}
handler.proxy = proxy
return handler
}

func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
endpointID := endpointIDFromRequest(r)
if endpointID == "" {
h.logger.Warn("request: missing endpoint id")
http.Error(w, `{"message": "missing endpoint id"}`, http.StatusBadGateway)
return
}

// nolint
ctx := context.WithValue(r.Context(), "_piko_endpoint", endpointID)
r = r.WithContext(ctx)

h.proxy.ServeHTTP(w, r)
}

// dialContext dials the endpoint ID in ctx. This is a bit of a hack to work
// with http.Transport.
func (h *Handler) dialContext(ctx context.Context, _, _ string) (net.Conn, error) {
// TODO(andydunstall): Alternatively wrap Transport.RoundTrip and first
// parse the endpoint ID, then decide whether to forward to a local
// connection or a remote node.
endpointID := ctx.Value("_piko_endpoint").(string)
return h.upstreams.Dial(endpointID)
}

// endpointIDFromRequest returns the endpoint ID from the HTTP request, or an
// empty string if no endpoint ID is specified.
//
// This will check both the 'x-piko-endpoint' header and 'Host' header, where
// x-piko-endpoint takes precedence.
func endpointIDFromRequest(r *http.Request) string {
endpointID := r.Header.Get("x-piko-endpoint")
if endpointID != "" {
return endpointID
}

host := r.Host
if host != "" && strings.Contains(host, ".") {
// If a host is given and contains a separator, use the bottom-level
// domain as the endpoint ID.
//
// Such as if the domain is 'xyz.piko.example.com', then 'xyz' is the
// endpoint ID.
return strings.Split(host, ".")[0]
}

return ""
}
88 changes: 88 additions & 0 deletions serverv2/reverseproxy/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package reverseproxy

import (
"context"
"fmt"
"net"
"net/http"

"github.com/andydunstall/piko/pkg/log"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)

type Server struct {
handler *Handler

router *gin.Engine

httpServer *http.Server

logger log.Logger
}

func NewServer(
upstreams UpstreamPool,
logger log.Logger,
) *Server {
logger = logger.WithSubsystem("reverseproxy")

router := gin.New()
server := &Server{
handler: NewHandler(upstreams, logger),
router: router,
httpServer: &http.Server{
Handler: router,
},
logger: logger,
}

// Recover from panics.
server.router.Use(gin.CustomRecoveryWithWriter(nil, server.panicRoute))

server.router.Use(NewLoggerMiddleware(true, logger))

server.registerRoutes()

return server
}

func (s *Server) Serve(ln net.Listener) error {
s.logger.Info("starting http server", zap.String("addr", ln.Addr().String()))
if err := s.httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
return fmt.Errorf("http serve: %w", err)
}
return nil
}

func (s *Server) Shutdown(ctx context.Context) error {
return s.httpServer.Shutdown(ctx)
}

func (s *Server) registerRoutes() {
// Handle not found routes, which includes all proxied endpoints.
s.router.NoRoute(s.notFoundRoute)
}

// proxyRoute handles proxied requests from proxy clients.
func (s *Server) proxyRoute(c *gin.Context) {
s.handler.ServeHTTP(c.Writer, c.Request)
}

func (s *Server) notFoundRoute(c *gin.Context) {
s.proxyRoute(c)
}

func (s *Server) panicRoute(c *gin.Context, err any) {
s.logger.Error(
"handler panic",
zap.String("path", c.FullPath()),
zap.Any("err", err),
)
c.AbortWithStatus(http.StatusInternalServerError)
}

func init() {
// Disable Gin debug logs.
gin.SetMode(gin.ReleaseMode)
}
10 changes: 10 additions & 0 deletions serverv2/reverseproxy/upstream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package reverseproxy

import (
"net"
)

// UpstreamPool contains the connected upstreams.
type UpstreamPool interface {
Dial(endpointID string) (net.Conn, error)
}
67 changes: 0 additions & 67 deletions serverv2/server/proxy/server.go

This file was deleted.

10 changes: 9 additions & 1 deletion serverv2/upstream/manager.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package upstream

import (
"encoding/binary"
"encoding/json"
"fmt"
"net"
Expand Down Expand Up @@ -51,7 +52,14 @@ func (m *Manager) Dial(endpointID string) (net.Conn, error) {
header := protocol.ProxyHeader{
EndpointID: endpointID,
}
if err := json.NewEncoder(stream).Encode(header); err != nil {
b, err := json.Marshal(header)
if err != nil {
return nil, fmt.Errorf("encode proxy header: %w", err)
}
if err := binary.Write(stream, binary.BigEndian, int64(len(b))); err != nil {
return nil, fmt.Errorf("write proxy header: %w", err)
}
if _, err := stream.Write(b); err != nil {
return nil, fmt.Errorf("write proxy header: %w", err)
}

Expand Down
Loading
Loading