diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 5899d2e..739425e 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -38,9 +38,6 @@ jobs: - name: Integration Tests run: go test ./... -tags integration -v - - name: System Tests - run: go test ./tests -tags system -v - lint: runs-on: ubuntu-latest needs: build diff --git a/agentv2/client/piko.go b/agent/client/client.go similarity index 74% rename from agentv2/client/piko.go rename to agent/client/client.go index 2839a8f..11df210 100644 --- a/agentv2/client/piko.go +++ b/agent/client/client.go @@ -1,4 +1,4 @@ -package piko +package client import ( "context" @@ -11,17 +11,17 @@ const ( defaultURL = "ws://localhost:8001" ) -// Piko manages registering and listening on endpoints. +// Client manages registering listeners with Piko. // // The client establishes an outbound-only connection to the server for each // listener. Proxied connections for the listener are then multiplexed over // that outbound connection. Therefore the client never exposes a port. -type Piko struct { +type Client struct { options options logger log.Logger } -func New(opts ...Option) *Piko { +func New(opts ...Option) *Client { options := options{ token: "", url: defaultURL, @@ -31,7 +31,7 @@ func New(opts ...Option) *Piko { o.apply(&options) } - return &Piko{ + return &Client{ options: options, logger: options.logger, } @@ -42,6 +42,6 @@ func New(opts ...Option) *Piko { // Listen will block until the listener has been registered. // // The returned [Listener] is a [net.Listener]. -func (p *Piko) Listen(ctx context.Context, endpointID string) (Listener, error) { - return listen(ctx, endpointID, p.options, p.logger) +func (c *Client) Listen(ctx context.Context, endpointID string) (Listener, error) { + return listen(ctx, endpointID, c.options, c.logger) } diff --git a/agentv2/client/listener.go b/agent/client/listener.go similarity index 95% rename from agentv2/client/listener.go rename to agent/client/listener.go index bf674d2..db483b2 100644 --- a/agentv2/client/listener.go +++ b/agent/client/listener.go @@ -1,4 +1,4 @@ -package piko +package client import ( "context" @@ -6,7 +6,6 @@ import ( "fmt" "net" "net/url" - "sync" "time" "github.com/andydunstall/piko/pkg/backoff" @@ -48,8 +47,7 @@ type Listener interface { type listener struct { endpointID string - mux *mux.Session - muxMu sync.Mutex + mux *mux.Session options options @@ -101,9 +99,7 @@ func (l *listener) Accept() (net.Conn, error) { return nil, err } - l.muxMu.Lock() l.mux = mux - l.muxMu.Unlock() } } @@ -114,11 +110,7 @@ func (l *listener) Addr() net.Addr { func (l *listener) Close() error { l.closeCancel() - l.muxMu.Lock() - err := l.mux.Close() - l.muxMu.Unlock() - - return err + return l.mux.Close() } func (l *listener) EndpointID() string { diff --git a/agentv2/client/options.go b/agent/client/options.go similarity index 98% rename from agentv2/client/options.go rename to agent/client/options.go index 5b4903c..7f88443 100644 --- a/agentv2/client/options.go +++ b/agent/client/options.go @@ -1,4 +1,4 @@ -package piko +package client import ( "crypto/tls" diff --git a/agent/config/config.go b/agent/config/config.go index f9e9851..316aacf 100644 --- a/agent/config/config.go +++ b/agent/config/config.go @@ -4,112 +4,79 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "net" "net/url" "os" + "strconv" "time" "github.com/andydunstall/piko/pkg/log" "github.com/spf13/pflag" ) -type EndpointConfig struct { - ID string `json:"id" yaml:"id"` +type 
ListenerConfig struct { + // EndpointID is the endpoint ID to register. + EndpointID string `json:"endpoint_id" yaml:"endpoint_id"` + + // Addr is the address of the upstream service to forward to. Addr string `json:"addr" yaml:"addr"` -} -type ServerConfig struct { - // URL is the server URL. - URL string `json:"url" yaml:"url"` - HeartbeatInterval time.Duration `json:"heartbeat_interval" yaml:"heartbeat_interval"` - HeartbeatTimeout time.Duration `json:"heartbeat_timeout" yaml:"heartbeat_timeout"` - ReconnectMinBackoff time.Duration `json:"reconnect_min_backoff" yaml:"reconnect_min_backoff"` - ReconnectMaxBackoff time.Duration `json:"reconnect_max_backoff" yaml:"reconnect_max_backoff"` + // AccessLog indicates whether to log all incoming connections and requests + // for the endpoint. + AccessLog bool `json:"access_log" yaml:"access_log"` + + // Timeout is the timeout to forward incoming requests to the upstream. + Timeout time.Duration `json:"timeout" yaml:"timeout"` } -func (c *ServerConfig) Validate() error { - if c.URL == "" { - return fmt.Errorf("missing url") +// URL parses the given upstream address into a URL. Return false if the +// address is invalid. +// +// The addr may be either a full URL, a host and port or just a port. +func (c *ListenerConfig) URL() (*url.URL, bool) { + // Port only. + port, err := strconv.Atoi(c.Addr) + if err == nil && port >= 0 && port < 0xffff { + return &url.URL{ + Scheme: "http", + Host: "localhost:" + c.Addr, + }, true } - if _, err := url.Parse(c.URL); err != nil { - return fmt.Errorf("invalid url: %w", err) + + // Host and port. + host, portStr, err := net.SplitHostPort(c.Addr) + if err == nil { + return &url.URL{ + Scheme: "http", + Host: net.JoinHostPort(host, portStr), + }, true } - if c.HeartbeatInterval == 0 { - return fmt.Errorf("missing heartbeat interval") + + // URL. + u, err := url.Parse(c.Addr) + if err == nil && u.Scheme != "" && u.Host != "" { + return u, true } - if c.HeartbeatTimeout == 0 { - return fmt.Errorf("missing heartbeat timeout") + + return nil, false +} + +func (c *ListenerConfig) Validate() error { + if c.EndpointID == "" { + return fmt.Errorf("missing endpoint id") } - if c.ReconnectMinBackoff == 0 { - return fmt.Errorf("missing reconnect min backoff") + if c.Addr == "" { + return fmt.Errorf("missing addr") } - if c.ReconnectMaxBackoff == 0 { - return fmt.Errorf("missing reconnect max backoff") + if _, ok := c.URL(); !ok { + return fmt.Errorf("invalid addr") + } + if c.Timeout == 0 { + return fmt.Errorf("missing timeout") } return nil } -func (c *ServerConfig) RegisterFlags(fs *pflag.FlagSet) { - fs.StringVar( - &c.URL, - "server.url", - "http://localhost:8001", - ` -Piko server URL. - -The listener will add path /piko/v1/listener/:endpoint_id to the given URL, -so if you include a path it will be used as a prefix. - -Note Piko connects to the server with WebSockets, so will replace http/https -with ws/wss (you can configure either).`, - ) - fs.DurationVar( - &c.HeartbeatInterval, - "server.heartbeat-interval", - time.Second*10, - ` -Heartbeat interval. - -To verify the connection to the server is ok, the listener sends a -heartbeat to the upstream at the '--server.heartbeat-interval' -interval, with a timeout of '--server.heartbeat-timeout'.`, - ) - fs.DurationVar( - &c.HeartbeatTimeout, - "server.heartbeat-timeout", - time.Second*10, - ` -Heartbeat timeout. 
- -To verify the connection to the server is ok, the listener sends a -heartbeat to the upstream at the '--server.heartbeat-interval' -interval, with a timeout of '--server.heartbeat-timeout'.`, - ) - fs.DurationVar( - &c.ReconnectMinBackoff, - "server.reconnect-min-backoff", - time.Millisecond*500, - ` -Minimum backoff when reconnecting to the server.`, - ) - fs.DurationVar( - &c.ReconnectMaxBackoff, - "server.reconnect-max-backoff", - time.Second*15, - ` -Maximum backoff when reconnecting to the server.`, - ) -} - -type AuthConfig struct { - APIKey string `json:"api_key" yaml:"api_key"` -} - -// ForwarderConfig contains the configuration for how to forward requests -// from Piko. -type ForwarderConfig struct { - Timeout time.Duration `json:"timeout" yaml:"timeout"` -} - type TLSConfig struct { // RootCAs contains a path to root certificate authorities to validate // the TLS connection to the Piko server. @@ -118,10 +85,11 @@ type TLSConfig struct { RootCAs string `json:"root_cas" yaml:"root_cas"` } -func (c *TLSConfig) RegisterFlags(fs *pflag.FlagSet) { +func (c *TLSConfig) RegisterFlags(fs *pflag.FlagSet, prefix string) { + prefix = prefix + ".tls." fs.StringVar( &c.RootCAs, - "tls.root-cas", + prefix+"root-cas", "", ` A path to a certificate PEM file containing root certificiate authorities to @@ -152,82 +120,148 @@ func (c *TLSConfig) Load() (*tls.Config, error) { return tlsConfig, nil } -type AdminConfig struct { +type ConnectConfig struct { + // URL is the Piko server URL to connect to. + URL string + + // Token is a token to authenticate with the Piko server. + Token string + + // Timeout is the timeout attempting to connect to the Piko server on + // boot. + Timeout time.Duration `json:"timeout" yaml:"timeout"` + + TLS TLSConfig `json:"tls" yaml:"tls"` +} + +func (c *ConnectConfig) Validate() error { + if c.URL == "" { + return fmt.Errorf("missing url") + } + if _, err := url.Parse(c.URL); err != nil { + return fmt.Errorf("invalid url: %w", err) + } + if c.Timeout == 0 { + return fmt.Errorf("missing timeout") + } + return nil +} + +func (c *ConnectConfig) RegisterFlags(fs *pflag.FlagSet) { + fs.StringVar( + &c.URL, + "connect.url", + "http://localhost:8001", + ` +The Piko server URL to connect to. Note this must be configured to use the +Piko server 'upstream' port.`, + ) + + fs.StringVar( + &c.Token, + "connect.token", + "", + ` +Token is a token to authenticate with the Piko server.`, + ) + + fs.DurationVar( + &c.Timeout, + "connect.timeout", + time.Second*30, + ` +Timeout attempting to connect to the Piko server on boot. Note if the agent +is disconnected after the initial connection succeeds it will keep trying to +reconnect.`, + ) + + c.TLS.RegisterFlags(fs, "connect") +} + +type ServerConfig struct { // BindAddr is the address to bind to listen for incoming HTTP connections. BindAddr string `json:"bind_addr" yaml:"bind_addr"` } -func (c *AdminConfig) Validate() error { +func (c *ServerConfig) Validate() error { if c.BindAddr == "" { return fmt.Errorf("missing bind addr") } return nil } +func (c *ServerConfig) RegisterFlags(fs *pflag.FlagSet) { + fs.StringVar( + &c.BindAddr, + "server.bind-addr", + ":5000", + ` +The host/port to bind the server to. 
+ +If the host is unspecified it defaults to all interfaces, for example +'--server.bind-addr :5000' will listen on '0.0.0.0:5000'.`, ) +} + type Config struct { - Endpoints []EndpointConfig `json:"endpoints" yaml:"endpoints"` - Server ServerConfig `json:"server" yaml:"server"` - Auth AuthConfig `json:"auth" yaml:"auth"` - Forwarder ForwarderConfig `json:"forwarder" yaml:"forwarder"` - TLS TLSConfig `json:"tls" yaml:"tls"` - Admin AdminConfig `json:"admin" yaml:"admin"` - Log log.Config `json:"log" yaml:"log"` + Listeners []ListenerConfig `json:"listeners" yaml:"listeners"` + + Connect ConnectConfig `json:"connect" yaml:"connect"` + + Server ServerConfig `json:"server" yaml:"server"` + + Log log.Config `json:"log" yaml:"log"` + + // GracePeriod is the duration to gracefully shutdown the agent. During + // the grace period, listeners and idle connections are closed, then the agent + // waits for active requests to complete and closes their connections. + GracePeriod time.Duration `json:"grace_period" yaml:"grace_period"` } func (c *Config) Validate() error { - if len(c.Endpoints) == 0 { - return fmt.Errorf("must have at least one endpoint") + // Note don't validate the number of listeners, as some commands don't + // require any. + for _, e := range c.Listeners { + if err := e.Validate(); err != nil { + if e.EndpointID != "" { + return fmt.Errorf("listener: %s: %w", e.EndpointID, err) + } + return fmt.Errorf("listener: %w", err) + } + } + + if err := c.Connect.Validate(); err != nil { + return fmt.Errorf("connect: %w", err) } if err := c.Server.Validate(); err != nil { return fmt.Errorf("server: %w", err) } - if err := c.Admin.Validate(); err != nil { - return fmt.Errorf("admin: %w", err) - } + if err := c.Log.Validate(); err != nil { return fmt.Errorf("log: %w", err) } + + if c.GracePeriod == 0 { + return fmt.Errorf("missing grace period") + } + return nil } func (c *Config) RegisterFlags(fs *pflag.FlagSet) { + c.Connect.RegisterFlags(fs) c.Server.RegisterFlags(fs) - - fs.StringVar( - &c.Auth.APIKey, - "auth.api-key", - "", - ` -An API key to authenticate the connection to Piko.`, - ) + c.Log.RegisterFlags(fs) fs.DurationVar( - &c.Forwarder.Timeout, - "forwarder.timeout", - time.Second*10, - ` -Forwarder timeout. - -This is the timeout between a listener receiving a request from Piko then -forwarding it to the configured forward address, and receiving a response. - -If the upstream does not respond within the given timeout a -'504 Gateway Timeout' is returned to the client.`, - ) - - c.TLS.RegisterFlags(fs) - - fs.StringVar( - &c.Admin.BindAddr, - "admin.bind-addr", - ":9000", + &c.GracePeriod, - "grace-period", hold + "grace-period", + time.Minute, ` -The host/port to listen for incoming admin connections. - -If the host is unspecified it defaults to all listeners, such as -'--admin.bind-addr :9000' will listen on '0.0.0.0:9000'`, +Maximum duration after a shutdown signal is received (SIGTERM or +SIGINT) to gracefully shutdown each listener.
+`, ) - c.Log.RegisterFlags(fs) } diff --git a/agentv2/config/config_test.go b/agent/config/config_test.go similarity index 89% rename from agentv2/config/config_test.go rename to agent/config/config_test.go index 4abd8ac..a20358a 100644 --- a/agentv2/config/config_test.go +++ b/agent/config/config_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestEndpointConfig_URL(t *testing.T) { +func TestListenerConfig_URL(t *testing.T) { tests := []struct { addr string url *url.URL @@ -43,7 +43,7 @@ func TestEndpointConfig_URL(t *testing.T) { for _, tt := range tests { tt := tt // for t.Parallel t.Run(tt.addr, func(t *testing.T) { - conf := &EndpointConfig{Addr: tt.addr} + conf := &ListenerConfig{Addr: tt.addr} u, ok := conf.URL() if !ok { assert.Equal(t, tt.ok, ok) diff --git a/agent/endpoint.go b/agent/endpoint.go deleted file mode 100644 index c9bb61e..0000000 --- a/agent/endpoint.go +++ /dev/null @@ -1,162 +0,0 @@ -package agent - -import ( - "context" - "crypto/tls" - "errors" - "fmt" - "net/http" - "net/url" - - "github.com/andydunstall/piko/agent/config" - "github.com/andydunstall/piko/pkg/backoff" - "github.com/andydunstall/piko/pkg/conn" - "github.com/andydunstall/piko/pkg/conn/websocket" - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/pkg/rpc" - "go.uber.org/zap" -) - -// Endpoint is responsible for registering with the Piko server then forwarding -// incoming requests to the forward address. -type Endpoint struct { - endpointID string - forwardAddr string - - forwarder *forwarder - - rpcServer *rpcServer - - conf *config.Config - - tlsConfig *tls.Config - - metrics *Metrics - - logger log.Logger -} - -func NewEndpoint( - endpointID string, - forwardAddr string, - conf *config.Config, - tlsConfig *tls.Config, - metrics *Metrics, - logger log.Logger, -) *Endpoint { - e := &Endpoint{ - endpointID: endpointID, - forwardAddr: forwardAddr, - forwarder: newForwarder( - endpointID, - forwardAddr, - conf.Forwarder.Timeout, - metrics, - logger, - ), - conf: conf, - tlsConfig: tlsConfig, - metrics: metrics, - logger: logger.WithSubsystem("endpoint").With(zap.String("endpoint-id", endpointID)), - } - e.rpcServer = newRPCServer(e, logger) - return e -} - -func (e *Endpoint) Run(ctx context.Context) error { - e.logger.Info( - "registering endpoint", - zap.String("forward-addr", e.forwardAddr), - ) - - for { - stream, err := e.connect(ctx) - if err != nil { - // connect only returns an error if it gets a non-retryable - // response or the context is cancelled, therefore return. - if errors.Is(err, context.Canceled) { - return nil - } - return err - } - defer stream.Close() - - e.logger.Debug("connected to server", zap.String("url", e.serverURL())) - - if err := stream.Monitor( - ctx, - e.conf.Server.HeartbeatInterval, - e.conf.Server.HeartbeatTimeout, - ); err != nil { - if ctx.Err() != nil { - // Shutdown. - return nil - } - - // Reconnect. - e.logger.Warn("disconnected", zap.Error(err)) - } - } -} - -func (e *Endpoint) ProxyHTTP(r *http.Request) (*http.Response, error) { - return e.forwarder.Forward(r) -} - -// connnect attempts to connect to the server. -// -// Retries with backoff until either the given context is cancelled or it gets -// a non-retryable response (such as an authentication error). -func (e *Endpoint) connect(ctx context.Context) (rpc.Stream, error) { - backoff := backoff.New( - // Retry forever. 
- 0, - e.conf.Server.ReconnectMinBackoff, - e.conf.Server.ReconnectMaxBackoff, - ) - for { - c, err := websocket.Dial( - ctx, - e.serverURL(), - websocket.WithToken(e.conf.Auth.APIKey), - websocket.WithTLSConfig(e.tlsConfig), - ) - if err == nil { - return rpc.NewStream(c, e.rpcServer.Handler(), e.logger), nil - } - - var retryableError *conn.RetryableError - if !errors.As(err, &retryableError) { - e.logger.Error( - "failed to connect to server; non-retryable", - zap.String("url", e.serverURL()), - zap.Error(err), - ) - return nil, fmt.Errorf("connect: %w", err) - } - - e.logger.Warn( - "failed to connect to server; retrying", - zap.String("url", e.serverURL()), - zap.Error(err), - ) - - if !backoff.Wait(ctx) { - return nil, ctx.Err() - } - } -} - -func (e *Endpoint) serverURL() string { - // Already verified URL in Config.Validate. - url, _ := url.Parse(e.conf.Server.URL) - url.Path = "/piko/v1/listener/" + e.endpointID - if url.Scheme == "http" { - url.Scheme = "ws" - } - if url.Scheme == "https" { - url.Scheme = "wss" - } - - return url.String() -} diff --git a/agent/forwarder.go b/agent/forwarder.go deleted file mode 100644 index 714daea..0000000 --- a/agent/forwarder.go +++ /dev/null @@ -1,117 +0,0 @@ -package agent - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "net/http" - "strconv" - "time" - - "github.com/andydunstall/piko/pkg/log" - "github.com/prometheus/client_golang/prometheus" - "go.uber.org/zap" -) - -var ( - errUpstreamTimeout = errors.New("upstream timeout") - errUpstreamUnreachable = errors.New("upstream unreachable") -) - -// forwarder manages forwarding incoming HTTP requests to the configured -// upstream. -type forwarder struct { - endpointID string - addr string - timeout time.Duration - - client *http.Client - - metrics *Metrics - - logger log.Logger -} - -func newForwarder( - endpointID string, - addr string, - timeout time.Duration, - metrics *Metrics, - logger log.Logger, -) *forwarder { - return &forwarder{ - endpointID: endpointID, - addr: addr, - timeout: timeout, - client: &http.Client{}, - metrics: metrics, - logger: logger.WithSubsystem("forwarder"), - } -} - -func (f *forwarder) Forward(req *http.Request) (*http.Response, error) { - ctx, cancel := context.WithTimeout(context.Background(), f.timeout) - defer cancel() - - req = req.WithContext(ctx) - - req.URL.Scheme = "http" - req.URL.Host = f.addr - req.RequestURI = "" - - start := time.Now() - - resp, err := f.client.Do(req) - if err != nil { - f.logger.Warn( - "failed to forward request", - zap.String("method", req.Method), - zap.String("host", req.URL.Host), - zap.String("path", req.URL.Path), - zap.Error(err), - ) - - f.metrics.ForwardErrorsTotal.With(prometheus.Labels{ - "endpoint_id": f.endpointID, - }).Inc() - - if errors.Is(err, context.DeadlineExceeded) { - return nil, errUpstreamTimeout - } - return nil, errUpstreamUnreachable - } - - // We must copy the response body before replying. The request context is - // cancelled when Forward returns so attempting top copy the body again - // may fail. - // - // TODO(andydunstall): Understand this better as should be avoidable. 
- var buf bytes.Buffer - if _, err := io.Copy(&buf, resp.Body); err != nil { - return nil, fmt.Errorf("copy body: %w", err) - } - resp.Body.Close() - resp.Body = io.NopCloser(&buf) - - f.metrics.ForwardRequestsTotal.With(prometheus.Labels{ - "method": req.Method, - "status": strconv.Itoa(resp.StatusCode), - "endpoint_id": f.endpointID, - }).Inc() - f.metrics.ForwardRequestLatency.With(prometheus.Labels{ - "status": strconv.Itoa(resp.StatusCode), - "endpoint_id": f.endpointID, - }).Observe(float64(time.Since(start).Milliseconds()) / 1000) - - f.logger.Debug( - "forward", - zap.String("method", req.Method), - zap.String("host", req.URL.Host), - zap.String("path", req.URL.Path), - zap.Int("status", resp.StatusCode), - ) - - return resp, nil -} diff --git a/agent/metrics.go b/agent/metrics.go deleted file mode 100644 index a9e0e1c..0000000 --- a/agent/metrics.go +++ /dev/null @@ -1,57 +0,0 @@ -package agent - -import "github.com/prometheus/client_golang/prometheus" - -type Metrics struct { - // ForwardRequestsTotal is the total number of requests send to the - // forward addr. Labelled by method, status code and endpoint ID. - ForwardRequestsTotal *prometheus.CounterVec - // ForwardErrorsTotal is the total number of errors from the forward - // address (not including bad status codes). Labelled by endpoint ID. - ForwardErrorsTotal *prometheus.CounterVec - // ForwardRequestLatency is a histogram of the latency of requests to the - // forward address in seconds. Labelled by response status code and - // endpoint ID. - ForwardRequestLatency *prometheus.HistogramVec -} - -func NewMetrics() *Metrics { - return &Metrics{ - ForwardRequestsTotal: prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: "piko", - Subsystem: "forward", - Name: "requests_total", - Help: "Total requests to the forward address.", - }, - []string{"status", "method", "endpoint_id"}, - ), - ForwardErrorsTotal: prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: "piko", - Subsystem: "forward", - Name: "errors_total", - Help: "Total errors from the forward address.", - }, - []string{"endpoint_id"}, - ), - ForwardRequestLatency: prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Namespace: "piko", - Subsystem: "forward", - Name: "request_latency_seconds", - Help: "Forward request latency in seconds", - Buckets: prometheus.DefBuckets, - }, - []string{"status", "endpoint_id"}, - ), - } -} - -func (m *Metrics) Register(registry *prometheus.Registry) { - registry.MustRegister( - m.ForwardRequestsTotal, - m.ForwardErrorsTotal, - m.ForwardRequestLatency, - ) -} diff --git a/agent/reverseproxy/reverseproxy.go b/agent/reverseproxy/reverseproxy.go new file mode 100644 index 0000000..732e38f --- /dev/null +++ b/agent/reverseproxy/reverseproxy.go @@ -0,0 +1,78 @@ +package reverseproxy + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httputil" + "time" + + "github.com/andydunstall/piko/agent/config" + "github.com/andydunstall/piko/pkg/log" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type ReverseProxy struct { + proxy *httputil.ReverseProxy + + timeout time.Duration + + logger log.Logger +} + +func NewReverseProxy(conf config.ListenerConfig, logger log.Logger) *ReverseProxy { + u, ok := conf.URL() + if !ok { + // We've already verified the address on boot so don't need to handle + // the error. 
+ panic("invalid addr: " + conf.Addr) + } + + proxy := httputil.NewSingleHostReverseProxy(u) + proxy.ErrorLog = logger.StdLogger(zapcore.WarnLevel) + rp := &ReverseProxy{ + proxy: proxy, + timeout: conf.Timeout, + logger: logger, + } + proxy.ErrorHandler = rp.errorHandler + return rp +} + +func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if p.timeout != 0 { + ctx, cancel := context.WithTimeout(r.Context(), p.timeout) + defer cancel() + + r = r.WithContext(ctx) + } + + p.proxy.ServeHTTP(w, r) +} + +func (p *ReverseProxy) errorHandler(w http.ResponseWriter, _ *http.Request, err error) { + p.logger.Warn("proxy request", zap.Error(err)) + + if errors.Is(err, context.DeadlineExceeded) { + _ = errorResponse(w, http.StatusGatewayTimeout, "upstream timeout") + return + } + _ = errorResponse(w, http.StatusBadGateway, "upstream unreachable") +} + +type errorMessage struct { + Error string `json:"error"` +} + +func errorResponse(w http.ResponseWriter, statusCode int, message string) error { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(statusCode) + + m := &errorMessage{ + Error: message, + } + return json.NewEncoder(w).Encode(m) +} diff --git a/agent/reverseproxy/reverseproxy_test.go b/agent/reverseproxy/reverseproxy_test.go new file mode 100644 index 0000000..ad56e2e --- /dev/null +++ b/agent/reverseproxy/reverseproxy_test.go @@ -0,0 +1,109 @@ +package reverseproxy + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/andydunstall/piko/agent/config" + "github.com/andydunstall/piko/pkg/log" + "github.com/stretchr/testify/assert" +) + +func TestReverseProxy_Forward(t *testing.T) { + t.Run("ok", func(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/foo/bar", r.URL.Path) + assert.Equal(t, "a=b", r.URL.RawQuery) + + buf := new(strings.Builder) + // nolint + io.Copy(buf, r.Body) + assert.Equal(t, "foo", buf.String()) + + // nolint + w.Write([]byte("bar")) + }, + )) + defer upstream.Close() + + proxy := NewReverseProxy(config.ListenerConfig{ + EndpointID: "my-endpoint", + Addr: upstream.URL, + }, log.NewNopLogger()) + + b := bytes.NewReader([]byte("foo")) + r := httptest.NewRequest(http.MethodGet, "/foo/bar?a=b", b) + + w := httptest.NewRecorder() + proxy.ServeHTTP(w, r) + + resp := w.Result() + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + buf := new(strings.Builder) + // nolint + io.Copy(buf, resp.Body) + assert.Equal(t, "bar", buf.String()) + }) + + t.Run("timeout", func(t *testing.T) { + blockCh := make(chan struct{}) + upstream := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + <-blockCh + }, + )) + defer upstream.Close() + defer close(blockCh) + + proxy := NewReverseProxy(config.ListenerConfig{ + EndpointID: "my-endpoint", + Addr: upstream.URL, + Timeout: time.Millisecond * 1, + }, log.NewNopLogger()) + + r := httptest.NewRequest(http.MethodGet, "/", nil) + + w := httptest.NewRecorder() + proxy.ServeHTTP(w, r) + + resp := w.Result() + defer resp.Body.Close() + + assert.Equal(t, http.StatusGatewayTimeout, resp.StatusCode) + + m := errorMessage{} + assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) + assert.Equal(t, "upstream timeout", m.Error) + }) + + t.Run("upstream unreachable", func(t *testing.T) { + proxy := 
NewReverseProxy(config.ListenerConfig{ + EndpointID: "my-endpoint", + Addr: "localhost:55555", + }, log.NewNopLogger()) + + r := httptest.NewRequest(http.MethodGet, "/", nil) + + w := httptest.NewRecorder() + proxy.ServeHTTP(w, r) + + resp := w.Result() + defer resp.Body.Close() + + assert.Equal(t, http.StatusBadGateway, resp.StatusCode) + + m := errorMessage{} + assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) + assert.Equal(t, "upstream unreachable", m.Error) + }) +} diff --git a/serverv2/reverseproxy/server.go b/agent/reverseproxy/server.go similarity index 61% rename from serverv2/reverseproxy/server.go rename to agent/reverseproxy/server.go index 25e31d6..f48a1fa 100644 --- a/serverv2/reverseproxy/server.go +++ b/agent/reverseproxy/server.go @@ -6,14 +6,17 @@ import ( "net" "net/http" + "github.com/andydunstall/piko/agent/config" "github.com/andydunstall/piko/pkg/log" + "github.com/andydunstall/piko/pkg/middleware" "github.com/gin-gonic/gin" + "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) type Server struct { - handler *ReverseProxy + proxy *ReverseProxy router *gin.Engine @@ -23,15 +26,17 @@ type Server struct { } func NewServer( - upstreams UpstreamManager, + conf config.ListenerConfig, + registry *prometheus.Registry, logger log.Logger, ) *Server { logger = logger.WithSubsystem("reverseproxy") + logger = logger.With(zap.String("endpoint-id", conf.EndpointID)) router := gin.New() - server := &Server{ - handler: NewReverseProxy(upstreams, logger), - router: router, + s := &Server{ + proxy: NewReverseProxy(conf, logger), + router: router, httpServer: &http.Server{ Handler: router, ErrorLog: logger.StdLogger(zapcore.WarnLevel), @@ -40,17 +45,24 @@ func NewServer( } // Recover from panics. - server.router.Use(gin.CustomRecoveryWithWriter(nil, server.panicRoute)) + s.router.Use(gin.CustomRecoveryWithWriter(nil, s.panicRoute)) - server.router.Use(NewLoggerMiddleware(true, logger)) + s.router.Use(middleware.NewLogger(conf.AccessLog, logger)) - server.registerRoutes() + metrics := middleware.NewMetrics("agent") + if registry != nil { + metrics.Register(registry) + } + router.Use(metrics.Handler()) + + s.router.NoRoute(s.proxyRoute) - return server + return s } func (s *Server) Serve(ln net.Listener) error { - s.logger.Info("starting http server", zap.String("addr", ln.Addr().String())) + s.logger.Info("starting reverse proxy") + if err := s.httpServer.Serve(ln); err != nil && err != http.ErrServerClosed { return fmt.Errorf("http serve: %w", err) } @@ -61,18 +73,8 @@ func (s *Server) Shutdown(ctx context.Context) error { return s.httpServer.Shutdown(ctx) } -func (s *Server) registerRoutes() { - // Handle not found routes, which includes all proxied endpoints. - s.router.NoRoute(s.notFoundRoute) -} - -// proxyRoute handles proxied requests from proxy clients. 
func (s *Server) proxyRoute(c *gin.Context) { - s.handler.ServeHTTP(c.Writer, c.Request) -} - -func (s *Server) notFoundRoute(c *gin.Context) { - s.proxyRoute(c) + s.proxy.ServeHTTP(c.Writer, c.Request) } func (s *Server) panicRoute(c *gin.Context, err any) { diff --git a/agent/rpcserver.go b/agent/rpcserver.go deleted file mode 100644 index 05f962a..0000000 --- a/agent/rpcserver.go +++ /dev/null @@ -1,121 +0,0 @@ -package agent - -import ( - "bufio" - "bytes" - "encoding/json" - "errors" - "io" - "net/http" - - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/pkg/rpc" - "go.uber.org/zap" -) - -type rpcServer struct { - endpoint *Endpoint - - rpcHandler *rpc.Handler - - logger log.Logger -} - -func newRPCServer(endpoint *Endpoint, logger log.Logger) *rpcServer { - server := &rpcServer{ - endpoint: endpoint, - rpcHandler: rpc.NewHandler(), - logger: logger.WithSubsystem("rpc.server"), - } - server.rpcHandler.Register(rpc.TypeHeartbeat, server.Heartbeat) - server.rpcHandler.Register(rpc.TypeProxyHTTP, server.ProxyHTTP) - return server -} - -func (s *rpcServer) Handler() *rpc.Handler { - return s.rpcHandler -} - -func (s *rpcServer) Heartbeat(b []byte) []byte { - // Echo any received payload. - s.logger.Debug("heartbeat rpc") - return b -} - -func (s *rpcServer) ProxyHTTP(b []byte) []byte { - s.logger.Debug("proxy http rpc") - - var httpResp *http.Response - - httpReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(b))) - if err != nil { - s.logger.Error("proxy http rpc; failed to decode http request", zap.Error(err)) - - httpResp = errorResponse( - http.StatusInternalServerError, - "internal error", - ) - } else { - httpResp = s.proxyHTTP(httpReq) - } - - defer httpResp.Body.Close() - - var buffer bytes.Buffer - if err := httpResp.Write(&buffer); err != nil { - s.logger.Error("proxy http rpc; failed to encode http response", zap.Error(err)) - return nil - } - - // TODO(andydunstall): Add header for internal errors. 
- - s.logger.Debug("proxy http rpc; ok", zap.String("path", httpReq.URL.Path)) - - return buffer.Bytes() -} - -func (s *rpcServer) proxyHTTP(r *http.Request) *http.Response { - s.logger.Debug("proxy http rpc") - - httpResp, err := s.endpoint.ProxyHTTP(r) - if err != nil { - if errors.Is(err, errUpstreamTimeout) { - s.logger.Warn("proxy http rpc; upstream timeout", zap.Error(err)) - - return errorResponse( - http.StatusGatewayTimeout, - "upstream timeout", - ) - } else if errors.Is(err, errUpstreamUnreachable) { - s.logger.Warn("proxy http rpc; upstream unreachable", zap.Error(err)) - - return errorResponse( - http.StatusServiceUnavailable, - "upstream unreachable", - ) - } else { - s.logger.Error("proxy http rpc; internal error", zap.Error(err)) - - return errorResponse( - http.StatusInternalServerError, - "internal error", - ) - } - } - return httpResp -} - -type errorMessage struct { - Error string `json:"error"` -} - -func errorResponse(statusCode int, message string) *http.Response { - m := &errorMessage{ - Error: message, - } - b, _ := json.Marshal(m) - return &http.Response{ - StatusCode: statusCode, - Body: io.NopCloser(bytes.NewReader(b)), - } -} diff --git a/agent/server/server.go b/agent/server/server.go new file mode 100644 index 0000000..2d18db2 --- /dev/null +++ b/agent/server/server.go @@ -0,0 +1,89 @@ +package server + +import ( + "context" + "fmt" + "net" + "net/http" + + "github.com/andydunstall/piko/pkg/log" + "github.com/gin-gonic/gin" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// Server is an agent server to inspect the status of the agent. +type Server struct { + registry *prometheus.Registry + + httpServer *http.Server + + logger log.Logger +} + +func NewServer(registry *prometheus.Registry, logger log.Logger) *Server { + logger = logger.WithSubsystem("server") + + router := gin.New() + server := &Server{ + registry: registry, + httpServer: &http.Server{ + Handler: router, + ErrorLog: logger.StdLogger(zapcore.WarnLevel), + }, + logger: logger, + } + + // Recover from panics. + router.Use(gin.CustomRecoveryWithWriter(nil, server.panicRoute)) + + server.registerRoutes(router) + + return server +} + +func (s *Server) Serve(ln net.Listener) error { + s.logger.Info( + "starting http server", + zap.String("addr", ln.Addr().String()), + ) + + if err := s.httpServer.Serve(ln); err != nil && err != http.ErrServerClosed { + return fmt.Errorf("http serve: %w", err) + } + + return nil +} + +// Shutdown attempts to gracefully shutdown the server by waiting for pending +// requests to complete. 
+func (s *Server) Shutdown(ctx context.Context) error { + return s.httpServer.Shutdown(ctx) +} + +func (s *Server) registerRoutes(router *gin.Engine) { + if s.registry != nil { + router.GET("/metrics", s.metricsHandler()) + } +} + +func (s *Server) panicRoute(c *gin.Context, err any) { + s.logger.Error( + "handler panic", + zap.String("path", c.FullPath()), + zap.Any("err", err), + ) + c.AbortWithStatus(http.StatusInternalServerError) +} + +func (s *Server) metricsHandler() gin.HandlerFunc { + h := promhttp.HandlerFor( + s.registry, + promhttp.HandlerOpts{Registry: s.registry}, + ) + return func(c *gin.Context) { + h.ServeHTTP(c.Writer, c.Request) + } +} diff --git a/agent/server/server_test.go b/agent/server/server_test.go new file mode 100644 index 0000000..0640c3a --- /dev/null +++ b/agent/server/server_test.go @@ -0,0 +1,46 @@ +package server + +import ( + "context" + "fmt" + "net" + "net/http" + "testing" + + "github.com/andydunstall/piko/pkg/log" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestServer_AdminRoutes(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + s := NewServer( + prometheus.NewRegistry(), + log.NewNopLogger(), + ) + go func() { + require.NoError(t, s.Serve(ln)) + }() + defer s.Shutdown(context.TODO()) + + t.Run("metrics", func(t *testing.T) { + url := fmt.Sprintf("http://%s/metrics", ln.Addr().String()) + resp, err := http.Get(url) + assert.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("not found", func(t *testing.T) { + url := fmt.Sprintf("http://%s/foo", ln.Addr().String()) + resp, err := http.Get(url) + assert.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + }) +} diff --git a/agentv2/config/config.go b/agentv2/config/config.go deleted file mode 100644 index 7cf4748..0000000 --- a/agentv2/config/config.go +++ /dev/null @@ -1,208 +0,0 @@ -package config - -import ( - "crypto/tls" - "crypto/x509" - "fmt" - "net" - "net/url" - "os" - "strconv" - "time" - - "github.com/andydunstall/piko/pkg/log" - "github.com/spf13/pflag" -) - -type EndpointConfig struct { - // ID is the endpoint ID to register. - ID string `json:"id" yaml:"id"` - - // Addr is the address of the upstream service to forward to. - Addr string `json:"addr" yaml:"addr"` - - // AccessLog indicates whether to log all incoming connections and requests - // for the endpoint. - AccessLog bool `json:"access_log" yaml:"access_log"` -} - -// URL parses the given upstream address into a URL. Return false if the -// address is invalid. -// -// The addr may be either a full URL, a host and port or just a port. -func (c *EndpointConfig) URL() (*url.URL, bool) { - // Port only. - port, err := strconv.Atoi(c.Addr) - if err == nil && port >= 0 && port < 0xffff { - return &url.URL{ - Scheme: "http", - Host: "localhost:" + c.Addr, - }, true - } - - // Host and port. - host, portStr, err := net.SplitHostPort(c.Addr) - if err == nil { - return &url.URL{ - Scheme: "http", - Host: net.JoinHostPort(host, portStr), - }, true - } - - // URL. 
- u, err := url.Parse(c.Addr) - if err == nil && u.Scheme != "" && u.Host != "" { - return u, true - } - - return nil, false -} - -func (c *EndpointConfig) Validate() error { - if c.ID == "" { - return fmt.Errorf("missing id") - } - if c.Addr == "" { - return fmt.Errorf("missing addr") - } - if _, ok := c.URL(); !ok { - return fmt.Errorf("invalid addr") - } - return nil -} - -type TLSConfig struct { - // RootCAs contains a path to root certificate authorities to validate - // the TLS connection to the Piko server. - // - // Defaults to using the host root CAs. - RootCAs string `json:"root_cas" yaml:"root_cas"` -} - -func (c *TLSConfig) RegisterFlags(fs *pflag.FlagSet) { - fs.StringVar( - &c.RootCAs, - "tls.root-cas", - "", - ` -A path to a certificate PEM file containing root certificiate authorities to -validate the TLS connection to the Piko server. - -Defaults to using the host root CAs.`, - ) -} - -func (c *TLSConfig) Load() (*tls.Config, error) { - if c.RootCAs == "" { - return nil, nil - } - - tlsConfig := &tls.Config{} - - caCert, err := os.ReadFile(c.RootCAs) - if err != nil { - return nil, fmt.Errorf("open root cas: %s: %w", c.RootCAs, err) - } - caCertPool := x509.NewCertPool() - ok := caCertPool.AppendCertsFromPEM(caCert) - if !ok { - return nil, fmt.Errorf("parse root cas: %s: %w", c.RootCAs, err) - } - tlsConfig.RootCAs = caCertPool - - return tlsConfig, nil -} - -type ConnectConfig struct { - Timeout time.Duration `json:"timeout" yaml:"timeout"` - TLS TLSConfig `json:"tls" yaml:"tls"` -} - -func (c *ConnectConfig) Validate() error { - if c.Timeout == 0 { - return fmt.Errorf("missing timeout") - } - return nil -} - -func (c *ConnectConfig) RegisterFlags(fs *pflag.FlagSet) { - fs.DurationVar( - &c.Timeout, - "connect.timeout", - time.Second*30, - ` -Timeout attempting to connect to the Piko server on boot. Note if the agent -is disconnected after the initial connection succeeds it will keep trying to -reconnect.`, - ) - c.TLS.RegisterFlags(fs) -} - -type Config struct { - Endpoints []EndpointConfig `json:"endpoints" yaml:"endpoints"` - - // Token is used to authenticate the agent with the server. - Token string `json:"token" yaml:"token"` - - Connect ConnectConfig `json:"connect" yaml:"connect"` - - Log log.Config `json:"log" yaml:"log"` - - // GracePeriod is the duration to gracefully shutdown the agent. During - // the grace period, listeners and idle connections are closed, then waits - // for active requests to complete and closes their connections. - GracePeriod time.Duration `json:"grace_period" yaml:"grace_period"` -} - -func (c *Config) Validate() error { - // Note don't validate the number of endpoints, as some commands don't - // require any. 
- for _, e := range c.Endpoints { - if err := e.Validate(); err != nil { - if e.ID != "" { - return fmt.Errorf("endpoint: %s: %w", e.ID, err) - } - return fmt.Errorf("endpoint: %w", err) - } - } - - if err := c.Connect.Validate(); err != nil { - return fmt.Errorf("connect: %w", err) - } - - if err := c.Log.Validate(); err != nil { - return fmt.Errorf("log: %w", err) - } - - if c.GracePeriod == 0 { - return fmt.Errorf("missing grace period") - } - - return nil -} - -func (c *Config) RegisterFlags(fs *pflag.FlagSet) { - fs.StringVar( - &c.Token, - "token", - "", - ` -A token to authenticate the connection to Piko.`, - ) - - c.Connect.RegisterFlags(fs) - - c.Log.RegisterFlags(fs) - - fs.DurationVar( - &c.GracePeriod, - "grace-period", - time.Minute, - ` -Maximum duration after a shutdown signal is received (SIGTERM or -SIGINT) to gracefully shutdown the server node before terminating. -This includes handling in-progress HTTP requests, gracefully closing -connections to upstream listeners, announcing to the cluster the node is -leaving...`, - ) -} diff --git a/agentv2/endpoint/endpoint.go b/agentv2/endpoint/endpoint.go deleted file mode 100644 index efa815e..0000000 --- a/agentv2/endpoint/endpoint.go +++ /dev/null @@ -1,95 +0,0 @@ -package endpoint - -import ( - "context" - "fmt" - "net" - "net/http" - - "github.com/andydunstall/piko/agentv2/config" - "github.com/andydunstall/piko/pkg/log" - "github.com/gin-gonic/gin" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -// Endpoint handles connections for the endpoint and forwards traffic the -// upstream listener. -type Endpoint struct { - id string - - handler *Handler - - router *gin.Engine - - httpServer *http.Server - - logger log.Logger -} - -func NewEndpoint(conf config.EndpointConfig, logger log.Logger) *Endpoint { - logger = logger.WithSubsystem("endpoint").With( - zap.String("endpoint-id", conf.ID), - ) - - router := gin.New() - endpoint := &Endpoint{ - id: conf.ID, - handler: NewHandler(conf, logger), - router: router, - httpServer: &http.Server{ - Handler: router, - ErrorLog: logger.StdLogger(zapcore.WarnLevel), - }, - logger: logger, - } - - // Recover from panics. - endpoint.router.Use(gin.CustomRecoveryWithWriter(nil, endpoint.panicRoute)) - - endpoint.router.Use(NewLoggerMiddleware(conf.AccessLog, logger)) - - endpoint.registerRoutes() - - return endpoint -} - -func (e *Endpoint) Serve(ln net.Listener) error { - e.logger.Info("starting endpoint") - if err := e.httpServer.Serve(ln); err != nil && err != http.ErrServerClosed { - return fmt.Errorf("http serve: %w", err) - } - return nil -} - -func (e *Endpoint) Shutdown(ctx context.Context) error { - return e.httpServer.Shutdown(ctx) -} - -func (e *Endpoint) registerRoutes() { - // Handle not found routes, which includes all proxied endpoints. - e.router.NoRoute(e.notFoundRoute) -} - -// proxyRoute handles proxied requests from proxy clients. -func (e *Endpoint) proxyRoute(c *gin.Context) { - e.handler.ServeHTTP(c.Writer, c.Request) -} - -func (e *Endpoint) notFoundRoute(c *gin.Context) { - e.proxyRoute(c) -} - -func (e *Endpoint) panicRoute(c *gin.Context, err any) { - e.logger.Error( - "handler panic", - zap.String("path", c.FullPath()), - zap.Any("err", err), - ) - c.AbortWithStatus(http.StatusInternalServerError) -} - -func init() { - // Disable Gin debug logs. 
- gin.SetMode(gin.ReleaseMode) -} diff --git a/agentv2/endpoint/endpoint_integration_test.go b/agentv2/endpoint/endpoint_integration_test.go deleted file mode 100644 index 75b373c..0000000 --- a/agentv2/endpoint/endpoint_integration_test.go +++ /dev/null @@ -1,91 +0,0 @@ -//go:build integration - -package endpoint - -import ( - "context" - "errors" - "fmt" - "net" - "net/http" - "net/http/httptest" - "testing" - - "github.com/andydunstall/piko/agentv2/config" - "github.com/andydunstall/piko/pkg/log" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -type fakeListener struct { - net.Listener - - endpointID string -} - -func (l *fakeListener) EndpointID() string { - return l.endpointID -} - -func TestEndpoint_Forward(t *testing.T) { - t.Run("ok", func(t *testing.T) { - upstream := httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) {}, - )) - defer upstream.Close() - - endpoint := NewEndpoint(config.EndpointConfig{ - ID: "my-endpoint", - Addr: upstream.URL, - }, log.NewNopLogger()) - - tcpLn, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - pikoLn := &fakeListener{ - Listener: tcpLn, - endpointID: "my-endpoint", - } - - go func() { - if err := endpoint.Serve(pikoLn); err != nil && !errors.Is(err, http.ErrServerClosed) { - require.NoError(t, err) - } - }() - defer endpoint.Shutdown(context.TODO()) - - url := fmt.Sprintf("http://%s/foo/bar", pikoLn.Addr().String()) - resp, err := http.Get(url) - assert.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, http.StatusOK, resp.StatusCode) - }) - - t.Run("no upstream", func(t *testing.T) { - endpoint := NewEndpoint(config.EndpointConfig{ - ID: "my-endpoint", - Addr: "55555", - }, log.NewNopLogger()) - - tcpLn, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - pikoLn := &fakeListener{ - Listener: tcpLn, - endpointID: "my-endpoint", - } - - go func() { - if err := endpoint.Serve(pikoLn); err != nil && !errors.Is(err, http.ErrServerClosed) { - require.NoError(t, err) - } - }() - defer endpoint.Shutdown(context.TODO()) - - url := fmt.Sprintf("http://%s/foo/bar", pikoLn.Addr().String()) - resp, err := http.Get(url) - assert.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, http.StatusBadGateway, resp.StatusCode) - }) -} diff --git a/agentv2/endpoint/handler.go b/agentv2/endpoint/handler.go deleted file mode 100644 index 66931d9..0000000 --- a/agentv2/endpoint/handler.go +++ /dev/null @@ -1,43 +0,0 @@ -package endpoint - -import ( - "net/http" - "net/http/httputil" - - "github.com/andydunstall/piko/agentv2/config" - "github.com/andydunstall/piko/pkg/log" - "go.uber.org/zap/zapcore" -) - -// Handler implements a reverse proxy HTTP handler that accepts requests from -// downstream clients and forwards them to upstream services. -type Handler struct { - proxy *httputil.ReverseProxy - - logger log.Logger -} - -func NewHandler(conf config.EndpointConfig, logger log.Logger) *Handler { - logger = logger.WithSubsystem("endpoint.reverseproxy") - - u, ok := conf.URL() - if !ok { - // We've already verified the address on boot so don't need to handle - // the error. 
- panic("invalid endpoint addr: " + conf.Addr) - } - - return &Handler{ - proxy: &httputil.ReverseProxy{ - Rewrite: func(r *httputil.ProxyRequest) { - r.SetURL(u) - }, - ErrorLog: logger.StdLogger(zapcore.WarnLevel), - }, - logger: logger, - } -} - -func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - h.proxy.ServeHTTP(w, r) -} diff --git a/cli/agent/command.go b/cli/agent/command.go index a06acf6..4182776 100644 --- a/cli/agent/command.go +++ b/cli/agent/command.go @@ -6,14 +6,14 @@ import ( "net" "os" "os/signal" - "strings" "syscall" - "github.com/andydunstall/piko/agent" + "github.com/andydunstall/piko/agent/client" "github.com/andydunstall/piko/agent/config" + "github.com/andydunstall/piko/agent/reverseproxy" + "github.com/andydunstall/piko/agent/server" pikoconfig "github.com/andydunstall/piko/pkg/config" "github.com/andydunstall/piko/pkg/log" - adminserver "github.com/andydunstall/piko/server/server/admin" rungroup "github.com/oklog/run" "github.com/prometheus/client_golang/prometheus" "github.com/spf13/cobra" @@ -22,143 +22,135 @@ import ( func NewCommand() *cobra.Command { cmd := &cobra.Command{ - Use: "agent [flags]", - Short: "start the piko agent", - Long: `Start the Piko agent. + Use: "agent [command] [flags]", + Short: "register endpoints and forward requests to your upstream services", + Long: `The Piko agent registers endpoints with Piko, then listens +for connections on those endpoints and forwards them to your upstream services. -The Piko agent is a command line tool that registers endpoints with Piko and -forwards requests to your upstream service. +Such as you may listen on endpoint 'my-endpoint' and forward connections to +your service at 'localhost:3000'. -To register an endpoint, you configure both the endpoint ID and address to -forward requests to (which will typically be on the same host as the agent). -Such as you may register endpoint 'my-endpoint' that forwards requests to -'localhost:4000'. +The agent opens an outbound connection to the Piko server for each listener, +then incoming connections from Piko are multiplexed over that outbound +connection. Therefore the agent never exposes a port. -For each registered endpoint, the agent will open an outbound-only connection -to Piko. This connection is used to receive proxied requests from the server -which are then forwarded to the configured address. - -If multiple upstreams register the same endpoint, Piko load balances requests -for the endpoint among the connected upstreams. +If there are multiple listeners for the same endpoint, Piko load balances +requests the registered listeners. The agent supports both YAML configuration and command line flags. Configure a YAML file using '--config.path'. When enabling '--config.expand-env', Piko will expand environment variables in the loaded YAML configuration. -Endpoints can be configured either using the YAML configuration or as command -line arguments. When using command line arguments, each endpoint has format -'/', such as 'my-endpoint-123/localhost:3000'. -For more advanced endpoint configurations use the YAML configuration. - Examples: - # Register an endpoint with ID 'my-endpoint-123' that forwards requests to - # to 'localhost:3000'. - piko agent my-endpoint-123/localhost:3000 - - # Register multiple endpoints. - piko agent my-endpoint-1/localhost:3000 my-endpoint-2/localhost:6000 + # Listen for connections from endpoint 'my-endpoint' and forward connections + # to localhost:3000. 
+ piko agent http my-endpoint 3000 - # Specify the Piko server address. - piko agent my-endpoint-123/localhost:3000 --server.url https://piko.example.com - - # Load configuration from a YAML file. - piko agent --config.path ./agent.yaml + # Start all listeners configured in agent.yaml. + piko agent start --config.file ./agent.yaml `, } var conf config.Config - - var configPath string - cmd.Flags().StringVar( - &configPath, - "config.path", - "", - ` -YAML config file path.`, - ) - - var configExpandEnv bool - cmd.Flags().BoolVar( - &configExpandEnv, - "config.expand-env", - false, - ` -Whether to expand environment variables in the config file. - -This will replaces references to ${VAR} or $VAR with the corresponding -environment variable. The replacement is case-sensitive. - -References to undefined variables will be replaced with an empty string. A -default value can be given using form ${VAR:default}.`, - ) + var loadConf pikoconfig.Config // Register flags and set default values. - conf.RegisterFlags(cmd.Flags()) + conf.RegisterFlags(cmd.PersistentFlags()) + loadConf.RegisterFlags(cmd.PersistentFlags()) - cmd.Run = func(cmd *cobra.Command, args []string) { - if configPath != "" { - if err := pikoconfig.Load(configPath, &conf, configExpandEnv); err != nil { - fmt.Printf("load config: %s\n", err.Error()) - os.Exit(1) - } - } - - for _, arg := range args { - elems := strings.Split(arg, "/") - if len(elems) != 2 { - fmt.Printf("invalid endpoint: %s\n", arg) - os.Exit(1) - } - - conf.Endpoints = append(conf.Endpoints, config.EndpointConfig{ - ID: elems[0], - Addr: elems[1], - }) - } - - if err := conf.Validate(); err != nil { - fmt.Printf("invalid config: %s\n", err.Error()) + cmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { + if err := loadConf.Load(&conf); err != nil { + fmt.Println(err.Error()) os.Exit(1) } - logger, err := log.NewLogger(conf.Log.Level, conf.Log.Subsystems) - if err != nil { - fmt.Printf("failed to setup logger: %s\n", err.Error()) - os.Exit(1) - } - - if err := run(&conf, logger); err != nil { - logger.Error("failed to run agent", zap.Error(err)) + if err := conf.Validate(); err != nil { + fmt.Printf("config: %s\n", err.Error()) os.Exit(1) } } + cmd.AddCommand(newStartCommand(&conf)) + cmd.AddCommand(newHTTPCommand(&conf)) + return cmd } -func run(conf *config.Config, logger log.Logger) error { - logger.Info("starting piko agent", zap.Any("conf", conf)) - - registry := prometheus.NewRegistry() +func runAgent(conf *config.Config, logger log.Logger) error { + logger.Info("starting piko agent") + logger.Debug("piko config", zap.Any("config", conf)) - adminLn, err := net.Listen("tcp", conf.Admin.BindAddr) + connectTLSConfig, err := conf.Connect.TLS.Load() if err != nil { - return fmt.Errorf("admin listen: %s: %w", conf.Admin.BindAddr, err) + return fmt.Errorf("connect tls: %w", err) } - adminServer := adminserver.NewServer( - adminLn, - nil, - nil, - registry, - logger, + + client := client.New( + client.WithToken(conf.Connect.Token), + client.WithURL(conf.Connect.URL), + client.WithTLSConfig(connectTLSConfig), + client.WithLogger(logger.WithSubsystem("client")), ) - endpointTLSConfig, err := conf.TLS.Load() + registry := prometheus.NewRegistry() + + var group rungroup.Group + + for _, listenerConfig := range conf.Listeners { + connectCtx, connectCancel := context.WithTimeout( + context.Background(), + conf.Connect.Timeout, + ) + defer connectCancel() + + ln, err := client.Listen(connectCtx, listenerConfig.EndpointID) + if err != nil { + return 
fmt.Errorf("listen: %s: %w", listenerConfig.EndpointID, err) + } + defer ln.Close() + + server := reverseproxy.NewServer(listenerConfig, registry, logger) + + // Listener handler. + group.Add(func() error { + if err := server.Serve(ln); err != nil { + return fmt.Errorf("serve: %w", err) + } + return nil + }, func(error) { + shutdownCtx, cancel := context.WithTimeout( + context.Background(), conf.GracePeriod, + ) + defer cancel() + + if err := server.Shutdown(shutdownCtx); err != nil { + logger.Warn("failed to gracefully shutdown listener", zap.Error(err)) + } + }) + } + + // Agent server. + serverLn, err := net.Listen("tcp", conf.Server.BindAddr) if err != nil { - return fmt.Errorf("tls: %w", err) + return fmt.Errorf("server listen: %s: %w", conf.Server.BindAddr, err) } + server := server.NewServer(registry, logger) - var group rungroup.Group + group.Add(func() error { + if err := server.Serve(serverLn); err != nil { + return fmt.Errorf("agent server: %w", err) + } + return nil + }, func(error) { + shutdownCtx, cancel := context.WithTimeout( + context.Background(), conf.GracePeriod, + ) + defer cancel() + + if err := server.Shutdown(shutdownCtx); err != nil { + logger.Warn("failed to gracefully shutdown agent server", zap.Error(err)) + } + }) // Termination handler. signalCtx, signalCancel := context.WithCancel(context.Background()) @@ -179,45 +171,5 @@ func run(conf *config.Config, logger log.Logger) error { signalCancel() }) - // Endpoints. - metrics := agent.NewMetrics() - metrics.Register(registry) - - for _, e := range conf.Endpoints { - endpoint := agent.NewEndpoint( - e.ID, e.Addr, conf, endpointTLSConfig, metrics, logger, - ) - - endpointCtx, endpointCancel := context.WithCancel(context.Background()) - group.Add(func() error { - if err := endpoint.Run(endpointCtx); err != nil { - return fmt.Errorf("endpoint: %s: %w", e.ID, err) - } - return nil - }, func(error) { - endpointCancel() - }) - } - - // Admin server. - group.Add(func() error { - if err := adminServer.Serve(); err != nil { - return fmt.Errorf("admin server serve: %w", err) - } - return nil - }, func(error) { - if err := adminServer.Close(); err != nil { - logger.Warn("failed to close server", zap.Error(err)) - } - - logger.Info("admin server shut down") - }) - - if err := group.Run(); err != nil { - return err - } - - logger.Info("shutdown complete") - - return nil + return group.Run() } diff --git a/cli/agent/http.go b/cli/agent/http.go new file mode 100644 index 0000000..171f9d9 --- /dev/null +++ b/cli/agent/http.go @@ -0,0 +1,83 @@ +package agent + +import ( + "fmt" + "os" + "time" + + "github.com/andydunstall/piko/agent/config" + "github.com/andydunstall/piko/pkg/log" + "github.com/spf13/cobra" + "go.uber.org/zap" +) + +func newHTTPCommand(conf *config.Config) *cobra.Command { + cmd := &cobra.Command{ + Use: "http [endpoint] [addr] [flags]", + Args: cobra.ExactArgs(2), + Short: "register a http listener", + Long: `Listens for HTTP traffic on the given endpoint and forwards +incoming connections to your upstream service. + +The configured upstream address be a port, host and port or a full URL. + +Examples: + # Listen for connections from endpoint 'my-endpoint' and forward connections + # to localhost:3000. + piko agent http my-endpoint 3000 + + # Listen and forward to 10.26.104.56:3000. + piko agent http my-endpoint 10.26.104.56:3000 + + # Listen and forward to 10.26.104.56:3000 using HTTPS. 
+ piko agent http my-endpoint https://10.26.104.56:3000 +`, + } + + var accessLog bool + cmd.Flags().BoolVar( + &accessLog, + "access-log", + true, + ` +Whether to log all incoming HTTP requests and responses as 'info' logs.`, + ) + + var timeout time.Duration + cmd.Flags().DurationVar( + &timeout, + "timeout", + time.Second*10, + ` +Timeout forwarding incoming HTTP requests to the upstream.`, + ) + + var logger log.Logger + + cmd.PreRun = func(cmd *cobra.Command, args []string) { + // Discard any listeners in the configuration file and use from command + // line. + conf.Listeners = []config.ListenerConfig{{ + EndpointID: args[0], + Addr: args[1], + AccessLog: accessLog, + Timeout: timeout, + }} + + var err error + logger, err = log.NewLogger(conf.Log.Level, conf.Log.Subsystems) + if err != nil { + fmt.Printf("failed to setup logger: %s\n", err.Error()) + os.Exit(1) + } + } + + cmd.Run = func(cmd *cobra.Command, args []string) { + if err := runAgent(conf, logger); err != nil { + logger.Error("failed to run agent", zap.Error(err)) + os.Exit(1) + } + } + + return cmd +} diff --git a/cli/agent/start.go b/cli/agent/start.go new file mode 100644 index 0000000..5bd9830 --- /dev/null +++ b/cli/agent/start.go @@ -0,0 +1,50 @@ +package agent + +import ( + "fmt" + "os" + + "github.com/andydunstall/piko/agent/config" + "github.com/andydunstall/piko/pkg/log" + "github.com/spf13/cobra" + "go.uber.org/zap" +) + +func newStartCommand(conf *config.Config) *cobra.Command { + cmd := &cobra.Command{ + Use: "start [flags]", + Short: "register the configured listeners", + Long: `Registers the configured listeners with Piko and forwards +incoming connections for each listener to your upstream services. + +Examples: + # Start all listeners configured in agent.yaml. + piko agent start --config.file ./agent.yaml +`, + } + + var logger log.Logger + + cmd.PreRun = func(cmd *cobra.Command, args []string) { + var err error + logger, err = log.NewLogger(conf.Log.Level, conf.Log.Subsystems) + if err != nil { + fmt.Printf("failed to setup logger: %s\n", err.Error()) + os.Exit(1) + } + + if len(conf.Listeners) == 0 { + fmt.Printf("no listeners configured\n") + os.Exit(1) + } + } + + cmd.Run = func(cmd *cobra.Command, args []string) { + if err := runAgent(conf, logger); err != nil { + logger.Error("failed to run agent", zap.Error(err)) + os.Exit(1) + } + } + + return cmd +} diff --git a/cli/agentv2/command.go b/cli/agentv2/command.go deleted file mode 100644 index 1c79ec3..0000000 --- a/cli/agentv2/command.go +++ /dev/null @@ -1,85 +0,0 @@ -package agent - -import ( - "fmt" - "os" - - "github.com/andydunstall/piko/agentv2/config" - pikoconfig "github.com/andydunstall/piko/pkg/config" - "github.com/spf13/cobra" -) - -func NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "agentv2 [command] [flags]", - Short: "piko agent", - Long: `The Piko agent registers endpoints with Piko then forwards -incoming connections for each endpoint to your upstream services. - -The agent opens a single outbound connection to the Piko server, which is used -to proxy connections and requests. Therefore the agent never exposes a port. - -The agent supports both YAML configuration and command line flags. Configure -a YAML file using '--config.path'. When enabling '--config.expand-env', Piko -will expand environment variables in the loaded YAML configuration. - -WARNING: Agent V2 is still in development... - -Examples: - # Register HTTP endpoint 'my-endpoint' for forward to localhost:3000. 
- piko agent http my-endpoint 3000 - - # Start all configured endpoints. - piko agent start --config.file ./agent.yaml -`, - // TODO(andydunstall): Hide while in development. - Hidden: true, - } - - var configPath string - cmd.PersistentFlags().StringVar( - &configPath, - "config.path", - "", - ` -YAML config file path.`, - ) - - var configExpandEnv bool - cmd.PersistentFlags().BoolVar( - &configExpandEnv, - "config.expand-env", - false, - ` -Whether to expand environment variables in the config file. - -This will replaces references to ${VAR} or $VAR with the corresponding -environment variable. The replacement is case-sensitive. - -References to undefined variables will be replaced with an empty string. A -default value can be given using form ${VAR:default}.`, - ) - - conf := &config.Config{} - conf.RegisterFlags(cmd.PersistentFlags()) - - cmd.AddCommand(newStartCommand(conf)) - cmd.AddCommand(newHTTPCommand(conf)) - - // Load the configuration but don't yet validate. - cmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { - if configPath != "" { - if err := pikoconfig.Load(configPath, &conf, configExpandEnv); err != nil { - fmt.Printf("load config: %s\n", err.Error()) - os.Exit(1) - } - } - - if err := conf.Validate(); err != nil { - fmt.Printf("config: %s\n", err.Error()) - os.Exit(1) - } - } - - return cmd -} diff --git a/cli/agentv2/http.go b/cli/agentv2/http.go deleted file mode 100644 index 69a31a3..0000000 --- a/cli/agentv2/http.go +++ /dev/null @@ -1,152 +0,0 @@ -package agent - -import ( - "context" - "fmt" - "os" - "os/signal" - "syscall" - - piko "github.com/andydunstall/piko/agentv2/client" - "github.com/andydunstall/piko/agentv2/config" - "github.com/andydunstall/piko/agentv2/endpoint" - "github.com/andydunstall/piko/pkg/log" - rungroup "github.com/oklog/run" - "github.com/spf13/cobra" - "go.uber.org/zap" -) - -func newHTTPCommand(conf *config.Config) *cobra.Command { - cmd := &cobra.Command{ - Use: "http [endpoint] [upstream addr] [flags]", - Short: "register a http listener", - Long: `Registers a HTTP endpoint with the given endpoint ID and -forwards incoming connections to your upstream service. - -The configured upstream address may be a port, a host and port, or a URL. - -Examples: - # Register endpoint 'my-endpoint' for forward incoming connections to - # localhost:3000. - piko agent http my-endpoint 3000 - - # Register and forward to 10.26.104.56:3000. - piko agent http my-endpoint 10.26.104.56:3000 - - # Register and forward to 10.26.104.56:3000 using HTTPS. - piko agent http my-endpoint https://10.26.104.56:3000 -`, - Args: cobra.ExactArgs(2), - } - - var accessLog bool - cmd.Flags().BoolVar( - &accessLog, - "access-log", - true, - ` -Whether to log all incoming HTTP requests and responses as 'info' logs.`, - ) - - var logger log.Logger - - cmd.PreRun = func(cmd *cobra.Command, args []string) { - // Discard any endpoints in the configuration file and use from command - // line. 
- conf.Endpoints = nil - conf.Endpoints = append(conf.Endpoints, config.EndpointConfig{ - ID: args[0], - Addr: args[1], - AccessLog: accessLog, - }) - - var err error - logger, err = log.NewLogger(conf.Log.Level, conf.Log.Subsystems) - if err != nil { - fmt.Printf("failed to setup logger: %s\n", err.Error()) - os.Exit(1) - } - } - - cmd.Run = func(cmd *cobra.Command, args []string) { - if err := runHTTP(conf, logger); err != nil { - logger.Error("failed to run agent", zap.Error(err)) - os.Exit(1) - } - } - - return cmd -} - -func runHTTP(conf *config.Config, logger log.Logger) error { - logger.Info("starting piko agent") - logger.Warn("piko agent v2 is still in development") - - // We know there is a single endpoint configured. - endpointConfig := conf.Endpoints[0] - endpoint := endpoint.NewEndpoint(endpointConfig, logger) - - connTLSConfig, err := conf.Connect.TLS.Load() - if err != nil { - return fmt.Errorf("tls: %w", err) - } - - client := piko.New( - piko.WithToken(conf.Token), - piko.WithTLSConfig(connTLSConfig), - piko.WithLogger(logger.WithSubsystem("client")), - ) - - connectCtx, connectCancel := context.WithTimeout( - context.Background(), - conf.Connect.Timeout, - ) - defer connectCancel() - - ln, err := client.Listen(connectCtx, endpointConfig.ID) - if err != nil { - return fmt.Errorf("listen: %s: %w", endpointConfig.ID, err) - } - defer ln.Close() - - var group rungroup.Group - - // Endpoint handler. - group.Add(func() error { - if err := endpoint.Serve(ln); err != nil { - return fmt.Errorf("serve: %w", err) - } - return nil - }, func(error) { - shutdownCtx, cancel := context.WithTimeout( - context.Background(), - conf.GracePeriod, - ) - defer cancel() - - if err := endpoint.Shutdown(shutdownCtx); err != nil { - logger.Warn("failed to gracefully shutdown endpoint", zap.Error(err)) - } - }) - - // Termination handler. - signalCtx, signalCancel := context.WithCancel(context.Background()) - signalCh := make(chan os.Signal, 1) - signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM) - group.Add(func() error { - select { - case sig := <-signalCh: - logger.Info( - "received shutdown signal", - zap.String("signal", sig.String()), - ) - return nil - case <-signalCtx.Done(): - return nil - } - }, func(error) { - signalCancel() - }) - - return group.Run() -} diff --git a/cli/agentv2/start.go b/cli/agentv2/start.go deleted file mode 100644 index 28fb540..0000000 --- a/cli/agentv2/start.go +++ /dev/null @@ -1,144 +0,0 @@ -package agent - -import ( - "context" - "fmt" - "os" - "os/signal" - "slices" - "syscall" - - piko "github.com/andydunstall/piko/agentv2/client" - "github.com/andydunstall/piko/agentv2/config" - "github.com/andydunstall/piko/agentv2/endpoint" - "github.com/andydunstall/piko/pkg/log" - rungroup "github.com/oklog/run" - "github.com/spf13/cobra" - "go.uber.org/zap" -) - -func newStartCommand(conf *config.Config) *cobra.Command { - cmd := &cobra.Command{ - Use: "start [endpoint...] [flags]", - Short: "register the configured endpoints", - Long: `Registers the configured endpoints with Piko then forwards -incoming connections for each endpoint to your upstream services. - -Examples: - # Start all configured endpoints. - piko agent start --config.file ./agent.yaml - - # Start only endpoints 'endpoint-1' and 'endpoint-2'. 
- piko agent start endpoint-1 endpoint-2 --config.file ./agent.yaml -`, - Args: cobra.MaximumNArgs(1), - } - - var logger log.Logger - - cmd.PreRun = func(cmd *cobra.Command, args []string) { - var err error - logger, err = log.NewLogger(conf.Log.Level, conf.Log.Subsystems) - if err != nil { - fmt.Printf("failed to setup logger: %s\n", err.Error()) - os.Exit(1) - } - - if len(conf.Endpoints) == 0 { - fmt.Printf("no endpoints configured\n") - os.Exit(1) - } - - // Verify the requested endpoints to start are configured. - var endpointIDs []string - for _, endpoint := range conf.Endpoints { - endpointIDs = append(endpointIDs, endpoint.ID) - } - for _, arg := range args { - if !slices.Contains(endpointIDs, arg) { - fmt.Printf("endpoint not found: %s\n", arg) - os.Exit(1) - } - } - } - - cmd.Run = func(cmd *cobra.Command, args []string) { - if err := runStart(args, conf, logger); err != nil { - logger.Error("failed to run agent", zap.Error(err)) - os.Exit(1) - } - } - - return cmd -} - -func runStart(endpoints []string, conf *config.Config, logger log.Logger) error { - logger.Info("starting piko agent") - logger.Warn("piko agent v2 is still in development") - - client := piko.New( - piko.WithToken(conf.Token), - piko.WithLogger(logger.WithSubsystem("client")), - ) - - var group rungroup.Group - - for _, endpointConfig := range conf.Endpoints { - if len(endpoints) != 0 && !slices.Contains(endpoints, endpointConfig.ID) { - continue - } - - connectCtx, connectCancel := context.WithTimeout( - context.Background(), - conf.Connect.Timeout, - ) - defer connectCancel() - - ln, err := client.Listen(connectCtx, endpointConfig.ID) - if err != nil { - return fmt.Errorf("listen: %s: %w", endpointConfig.ID, err) - } - defer ln.Close() - - endpoint := endpoint.NewEndpoint(endpointConfig, logger) - - // Endpoint handler. - group.Add(func() error { - if err := endpoint.Serve(ln); err != nil { - return fmt.Errorf("serve: %w", err) - } - return nil - }, func(error) { - shutdownCtx, cancel := context.WithTimeout( - context.Background(), - conf.GracePeriod, - ) - defer cancel() - - if err := endpoint.Shutdown(shutdownCtx); err != nil { - logger.Warn("failed to gracefully shutdown endpoint", zap.Error(err)) - } - }) - } - - // Termination handler. - signalCtx, signalCancel := context.WithCancel(context.Background()) - signalCh := make(chan os.Signal, 1) - signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM) - group.Add(func() error { - select { - case sig := <-signalCh: - logger.Info( - "received shutdown signal", - zap.String("signal", sig.String()), - ) - return nil - case <-signalCtx.Done(): - return nil - } - }, func(error) { - signalCancel() - }) - - return group.Run() -} diff --git a/cli/command.go b/cli/command.go index 94b6317..0affc0e 100644 --- a/cli/command.go +++ b/cli/command.go @@ -2,10 +2,7 @@ package cli import ( "github.com/andydunstall/piko/cli/agent" - agentv2 "github.com/andydunstall/piko/cli/agentv2" "github.com/andydunstall/piko/cli/server" - serverv2 "github.com/andydunstall/piko/cli/serverv2" - "github.com/andydunstall/piko/cli/status" "github.com/andydunstall/piko/cli/workload" "github.com/spf13/cobra" ) @@ -33,25 +30,22 @@ Start a server node with: You can also inspect the status of the server using: - $ piko status + $ piko server status To register an upstream service, use the Piko agent. The agent is a lightweight proxy that runs alongside your services. 
It connects to the Piko server, registers the configured endpoints, then forwards incoming requests to your services. -Such as to register endpoint 'my-endpoint' that forwards incoming requests to -'localhost:4000', use: +Such as to register endpoint 'my-endpoint' to forward connections to your +service at 'localhost:3000': - $ piko agent my-endpoint/localhost:4000 + $ piko agent http my-endpoint 3000 `, } - cmd.AddCommand(agent.NewCommand()) - cmd.AddCommand(agentv2.NewCommand()) cmd.AddCommand(server.NewCommand()) - cmd.AddCommand(serverv2.NewCommand()) - cmd.AddCommand(status.NewCommand()) + cmd.AddCommand(agent.NewCommand()) cmd.AddCommand(workload.NewCommand()) return cmd diff --git a/cli/server/command.go b/cli/server/command.go index c6d786f..e9f27c9 100644 --- a/cli/server/command.go +++ b/cli/server/command.go @@ -3,33 +3,50 @@ package server import ( "context" "fmt" + "net" "os" "os/signal" + "strings" "syscall" + "github.com/andydunstall/piko/cli/server/status" pikoconfig "github.com/andydunstall/piko/pkg/config" "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/server" + "github.com/andydunstall/piko/server/admin" + "github.com/andydunstall/piko/server/auth" + "github.com/andydunstall/piko/server/cluster" "github.com/andydunstall/piko/server/config" + "github.com/andydunstall/piko/server/gossip" + "github.com/andydunstall/piko/server/proxy" + "github.com/andydunstall/piko/server/upstream" + "github.com/golang-jwt/jwt/v5" + "github.com/hashicorp/go-sockaddr" rungroup "github.com/oklog/run" + "github.com/prometheus/client_golang/prometheus" "github.com/spf13/cobra" "go.uber.org/zap" ) func NewCommand() *cobra.Command { cmd := &cobra.Command{ - Use: "server", + Use: "server [flags]", Short: "start a server node", - Long: `Start a server node. + Long: `The Piko server is responsible for routing incoming proxy +requests and connections to upstream services. Upstream services listen for +traffic on a particular endpoint by opening an outbound-only connection to the +server. Piko then routes traffic for each endpoint to an appropriate upstream +connection. -The Piko server is responsible for routing incoming proxy requests to upstream -services. Upstream services open outbound-connections to the server and -register endpoints. Piko will then route incoming requests to the appropriate -upstream service via the upstreams outbound-only connection. +Use '--cluster.join' to run the server as a cluster of nodes, where you can +specify either a list of addresses of existing members, or a domain that +resolves to the addresses of existing members. -Piko may run as a cluster of nodes for fault tolerance and scalability. Use -'--cluster.join' to configure addresses of existing members in the cluster -to join. +The server exposes 4 ports: +- Proxy port: Receives HTTP(S) requests from proxy clients which are routed +to an upstream service +- Upstream port: Accepts connections from upstream services +- Admin port: Exposes metrics and a status API to inspect the server state +- Gossip port: Used for inter-node gossip traffic The server supports both YAML configuration and command line flags. Configure a YAML file using '--config.path'. 
When enabling '--config.expand-env', Piko @@ -53,70 +70,315 @@ Examples: } var conf config.Config - - var configPath string - cmd.Flags().StringVar( - &configPath, - "config.path", - "", - ` -YAML config file path.`, - ) - - var configExpandEnv bool - cmd.Flags().BoolVar( - &configExpandEnv, - "config.expand-env", - false, - ` -Whether to expand environment variables in the config file. - -This will replaces references to ${VAR} or $VAR with the corresponding -environment variable. The replacement is case-sensitive. - -References to undefined variables will be replaced with an empty string. A -default value can be given using form ${VAR:default}.`, - ) + var loadConf pikoconfig.Config // Register flags and set default values. conf.RegisterFlags(cmd.Flags()) + loadConf.RegisterFlags(cmd.Flags()) - cmd.Run = func(cmd *cobra.Command, args []string) { - if configPath != "" { - if err := pikoconfig.Load(configPath, &conf, configExpandEnv); err != nil { - fmt.Printf("load config: %s: %s\n", configPath, err.Error()) - os.Exit(1) + var logger log.Logger + + cmd.PreRun = func(cmd *cobra.Command, args []string) { + if err := loadConf.Load(&conf); err != nil { + fmt.Println(err.Error()) + os.Exit(1) + } + + if conf.Cluster.NodeID == "" { + nodeID := cluster.GenerateNodeID() + if conf.Cluster.NodeIDPrefix != "" { + nodeID = conf.Cluster.NodeIDPrefix + nodeID } + conf.Cluster.NodeID = nodeID } if err := conf.Validate(); err != nil { - fmt.Printf("invalid config: %s\n", err.Error()) + fmt.Printf("config: %s\n", err.Error()) os.Exit(1) } - logger, err := log.NewLogger(conf.Log.Level, conf.Log.Subsystems) + var err error + logger, err = log.NewLogger(conf.Log.Level, conf.Log.Subsystems) if err != nil { fmt.Printf("failed to setup logger: %s\n", err.Error()) os.Exit(1) } - if err := run(&conf, logger); err != nil { - logger.Error("failed to run server", zap.Error(err)) + if conf.Proxy.AdvertiseAddr == "" { + advertiseAddr, err := advertiseAddrFromBindAddr(conf.Proxy.BindAddr) + if err != nil { + logger.Error("invalid configuration", zap.Error(err)) + os.Exit(1) + } + conf.Proxy.AdvertiseAddr = advertiseAddr + } + + if conf.Admin.AdvertiseAddr == "" { + advertiseAddr, err := advertiseAddrFromBindAddr(conf.Admin.BindAddr) + if err != nil { + logger.Error("invalid configuration", zap.Error(err)) + os.Exit(1) + } + conf.Admin.AdvertiseAddr = advertiseAddr + } + + if conf.Gossip.AdvertiseAddr == "" { + advertiseAddr, err := advertiseAddrFromBindAddr(conf.Gossip.BindAddr) + if err != nil { + logger.Error("invalid configuration", zap.Error(err)) + os.Exit(1) + } + conf.Gossip.AdvertiseAddr = advertiseAddr + } + } + + cmd.Run = func(cmd *cobra.Command, args []string) { + if err := runServer(&conf, logger); err != nil { + logger.Error("failed to run agent", zap.Error(err)) os.Exit(1) } } + cmd.AddCommand(status.NewCommand()) + return cmd } -func run(conf *config.Config, logger log.Logger) error { - server, err := server.NewServer(conf, logger) - if err != nil { - return fmt.Errorf("server: %w", err) +func runServer(conf *config.Config, logger log.Logger) error { + var verifier auth.Verifier + if conf.Auth.AuthEnabled() { + verifierConf := auth.JWTVerifierConfig{ + HMACSecretKey: []byte(conf.Auth.TokenHMACSecretKey), + Audience: conf.Auth.TokenAudience, + Issuer: conf.Auth.TokenIssuer, + } + + if conf.Auth.TokenRSAPublicKey != "" { + rsaPublicKey, err := jwt.ParseRSAPublicKeyFromPEM( + []byte(conf.Auth.TokenRSAPublicKey), + ) + if err != nil { + return fmt.Errorf("parse rsa public key: %w", err) + } + 
verifierConf.RSAPublicKey = rsaPublicKey + } + if conf.Auth.TokenECDSAPublicKey != "" { + ecdsaPublicKey, err := jwt.ParseECPublicKeyFromPEM( + []byte(conf.Auth.TokenECDSAPublicKey), + ) + if err != nil { + return fmt.Errorf("parse ecdsa public key: %w", err) + } + verifierConf.ECDSAPublicKey = ecdsaPublicKey + } + verifier = auth.NewJWTVerifier(verifierConf) } + logger.Info("starting piko server", zap.String("node-id", conf.Cluster.NodeID)) + logger.Debug("piko config", zap.Any("config", conf)) + + registry := prometheus.NewRegistry() + + clusterState := cluster.NewState(&cluster.Node{ + ID: conf.Cluster.NodeID, + ProxyAddr: conf.Proxy.AdvertiseAddr, + AdminAddr: conf.Admin.AdvertiseAddr, + }, logger) + clusterState.Metrics().Register(registry) + + upstreams := upstream.NewLoadBalancedManager(clusterState) + upstreams.Metrics().Register(registry) + var group rungroup.Group + // Gossip. + + gossipStreamLn, err := net.Listen("tcp", conf.Gossip.BindAddr) + if err != nil { + return fmt.Errorf("gossip listen: %s: %w", conf.Gossip.BindAddr, err) + } + + gossipPacketLn, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: gossipStreamLn.Addr().(*net.TCPAddr).IP, + Port: gossipStreamLn.Addr().(*net.TCPAddr).Port, + }) + if err != nil { + return fmt.Errorf("gossip listen: %s: %w", conf.Gossip.BindAddr, err) + } + + gossiper := gossip.NewGossip( + clusterState, + gossipStreamLn, + gossipPacketLn, + &conf.Gossip, + logger, + ) + defer gossiper.Close() + gossiper.Metrics().Register(registry) + + // Attempt to join an existing cluster. + // + // Note when running on Kubernetes, if this is the first member, as it is + // not yet ready the service DNS record won't resolve so this may fail. + // Therefore we attempt to join though continue booting if join fails. + // Once booted we then attempt to join again with retries. + nodeIDs, err := gossiper.JoinOnBoot(conf.Cluster.Join) + if err != nil { + logger.Warn("failed to join cluster", zap.Error(err)) + } + if len(nodeIDs) > 0 { + logger.Info( + "joined cluster", + zap.Strings("node-ids", nodeIDs), + ) + } + + gossipCtx, gossipCancel := context.WithCancel(context.Background()) + group.Add(func() error { + if len(nodeIDs) == 0 { + nodeIDs, err = gossiper.JoinOnStartup(gossipCtx, conf.Cluster.Join) + if err != nil { + if conf.Cluster.AbortIfJoinFails { + return fmt.Errorf("join on startup: %w", err) + } + logger.Warn("failed to join cluster", zap.Error(err)) + } + if len(nodeIDs) > 0 { + logger.Info( + "joined cluster", + zap.Strings("node-ids", nodeIDs), + ) + } + } + + <-gossipCtx.Done() + + leaveCtx, cancel := context.WithTimeout( + context.Background(), + conf.GracePeriod, + ) + defer cancel() + + // Leave as soon as we receive the shutdown signal to avoid receiving + // forward proxy requests. + if err := gossiper.Leave(leaveCtx); err != nil { + logger.Warn("failed to gracefully leave cluster", zap.Error(err)) + } else { + logger.Info("left cluster") + } + + return nil + }, func(error) { + gossipCancel() + }) + + // Proxy server. 
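+	// The proxy server accepts HTTP(S) requests from proxy clients and routes
+	// each request to an upstream connection registered for the target endpoint.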
+ proxyLn, err := net.Listen("tcp", conf.Proxy.BindAddr) + if err != nil { + return fmt.Errorf("proxy listen: %s: %w", conf.Proxy.BindAddr, err) + } + proxyTLSConfig, err := conf.Proxy.TLS.Load() + if err != nil { + return fmt.Errorf("proxy tls: %w", err) + } + proxyServer := proxy.NewServer( + upstreams, + conf.Proxy, + registry, + proxyTLSConfig, + logger, + ) + + group.Add(func() error { + if err := proxyServer.Serve(proxyLn); err != nil { + return fmt.Errorf("proxy server serve: %w", err) + } + return nil + }, func(error) { + shutdownCtx, cancel := context.WithTimeout( + context.Background(), + conf.GracePeriod, + ) + defer cancel() + + if err := proxyServer.Shutdown(shutdownCtx); err != nil { + logger.Warn("failed to gracefully shutdown server", zap.Error(err)) + } + + logger.Info("proxy server shut down") + }) + + // Upstream server. + upstreamLn, err := net.Listen("tcp", conf.Upstream.BindAddr) + if err != nil { + return fmt.Errorf("upstream listen: %s: %w", conf.Upstream.BindAddr, err) + } + upstreamTLSConfig, err := conf.Upstream.TLS.Load() + if err != nil { + return fmt.Errorf("upstream tls: %w", err) + } + upstreamServer := upstream.NewServer( + upstreams, + verifier, + upstreamTLSConfig, + logger, + ) + + group.Add(func() error { + if err := upstreamServer.Serve(upstreamLn); err != nil { + return fmt.Errorf("upstream server serve: %w", err) + } + return nil + }, func(error) { + shutdownCtx, cancel := context.WithTimeout( + context.Background(), + conf.GracePeriod, + ) + defer cancel() + + if err := upstreamServer.Shutdown(shutdownCtx); err != nil { + logger.Warn("failed to gracefully shutdown server", zap.Error(err)) + } + + logger.Info("upstream server shut down") + }) + + // Admin Server. + adminLn, err := net.Listen("tcp", conf.Admin.BindAddr) + if err != nil { + return fmt.Errorf("admin listen: %s: %w", conf.Admin.BindAddr, err) + } + adminTLSConfig, err := conf.Admin.TLS.Load() + if err != nil { + return fmt.Errorf("admin tls: %w", err) + } + adminServer := admin.NewServer( + registry, + adminTLSConfig, + logger, + ) + adminServer.AddStatus("/cluster", cluster.NewStatus(clusterState)) + adminServer.AddStatus("/gossip", gossip.NewStatus(gossiper)) + + group.Add(func() error { + if err := adminServer.Serve(adminLn); err != nil { + return fmt.Errorf("admin server serve: %w", err) + } + return nil + }, func(error) { + shutdownCtx, cancel := context.WithTimeout( + context.Background(), + conf.GracePeriod, + ) + defer cancel() + + if err := adminServer.Shutdown(shutdownCtx); err != nil { + logger.Warn("failed to gracefully shutdown server", zap.Error(err)) + } + + logger.Info("admin server shut down") + }) + // Termination handler. 
signalCtx, signalCancel := context.WithCancel(context.Background()) signalCh := make(chan os.Signal, 1) @@ -136,12 +398,34 @@ func run(conf *config.Config, logger log.Logger) error { signalCancel() }) - runCtx, runCancel := context.WithCancel(context.Background()) - group.Add(func() error { - return server.Run(runCtx) - }, func(error) { - runCancel() - }) + if err := group.Run(); err != nil { + return err + } + + logger.Info("shutdown complete") + + return nil +} + +func advertiseAddrFromBindAddr(bindAddr string) (string, error) { + if strings.HasPrefix(bindAddr, ":") { + bindAddr = "0.0.0.0" + bindAddr + } - return group.Run() + host, port, err := net.SplitHostPort(bindAddr) + if err != nil { + return "", fmt.Errorf("invalid bind addr: %s: %w", bindAddr, err) + } + + if host == "0.0.0.0" || host == "::" { + ip, err := sockaddr.GetPrivateIP() + if err != nil { + return "", fmt.Errorf("get interface addr: %w", err) + } + if ip == "" { + return "", fmt.Errorf("no private ip found") + } + return ip + ":" + port, nil + } + return bindAddr, nil } diff --git a/cli/status/cluster.go b/cli/server/status/cluster.go similarity index 52% rename from cli/status/cluster.go rename to cli/server/status/cluster.go index a4b26dd..7c3d5b6 100644 --- a/cli/status/cluster.go +++ b/cli/server/status/cluster.go @@ -2,30 +2,28 @@ package status import ( "fmt" - "net/url" "os" "sort" "github.com/andydunstall/piko/server/cluster" - "github.com/andydunstall/piko/status/client" - "github.com/andydunstall/piko/status/config" + "github.com/andydunstall/piko/server/status/client" yaml "github.com/goccy/go-yaml" "github.com/spf13/cobra" ) -func newClusterCommand() *cobra.Command { +func newClusterCommand(c *client.Client) *cobra.Command { cmd := &cobra.Command{ Use: "cluster", Short: "inspect proxy cluster", } - cmd.AddCommand(newClusterNodesCommand()) - cmd.AddCommand(newClusterNodeCommand()) + cmd.AddCommand(newClusterNodesCommand(c)) + cmd.AddCommand(newClusterNodeCommand(c)) return cmd } -func newClusterNodesCommand() *cobra.Command { +func newClusterNodesCommand(c *client.Client) *cobra.Command { cmd := &cobra.Command{ Use: "nodes", Short: "inspect cluster nodes", @@ -35,20 +33,12 @@ Queries the server for the set of nodes the cluster that this node knows about. The output contains the state of each known node. Examples: - piko status cluster nodes + piko server status cluster nodes `, } - var conf config.Config - conf.RegisterFlags(cmd.Flags()) - cmd.Run = func(cmd *cobra.Command, args []string) { - if err := conf.Validate(); err != nil { - fmt.Printf("invalid config: %s\n", err.Error()) - os.Exit(1) - } - - showClusterNodes(&conf) + showClusterNodes(c) } return cmd @@ -58,13 +48,10 @@ type clusterNodesOutput struct { Nodes []*cluster.NodeMetadata `json:"nodes"` } -func showClusterNodes(conf *config.Config) { - // The URL has already been validated in conf. 
- url, _ := url.Parse(conf.Server.URL) - client := client.NewClient(url, conf.Forward) - defer client.Close() +func showClusterNodes(c *client.Client) { + cluster := client.NewCluster(c) - nodes, err := client.ClusterNodes() + nodes, err := cluster.Nodes() if err != nil { fmt.Printf("failed to get cluster nodes: %s\n", err.Error()) os.Exit(1) @@ -79,10 +66,10 @@ func showClusterNodes(conf *config.Config) { Nodes: nodes, } b, _ := yaml.Marshal(output) - fmt.Println(string(b)) + fmt.Print(string(b)) } -func newClusterNodeCommand() *cobra.Command { +func newClusterNodeCommand(c *client.Client) *cobra.Command { cmd := &cobra.Command{ Use: "node", Args: cobra.ExactArgs(1), @@ -94,40 +81,29 @@ a node ID of 'local' to query the local node. Examples: # Inspect node bbc69214. - piko status cluster node bbc69214 + piko server status cluster node bbc69214 # Inspect local node. - piko status cluster node local + piko server status cluster node local `, } - var conf config.Config - conf.RegisterFlags(cmd.Flags()) - cmd.Run = func(cmd *cobra.Command, args []string) { - if err := conf.Validate(); err != nil { - fmt.Printf("invalid config: %s\n", err.Error()) - os.Exit(1) - } - - showClusterNode(args[0], &conf) + showClusterNode(args[0], c) } return cmd } -func showClusterNode(nodeID string, conf *config.Config) { - // The URL has already been validated in conf. - url, _ := url.Parse(conf.Server.URL) - client := client.NewClient(url, conf.Forward) - defer client.Close() +func showClusterNode(nodeID string, c *client.Client) { + cluster := client.NewCluster(c) - node, err := client.ClusterNode(nodeID) + node, err := cluster.Node(nodeID) if err != nil { fmt.Printf("failed to get cluster nodes: %s: %s\n", nodeID, err.Error()) os.Exit(1) } b, _ := yaml.Marshal(node) - fmt.Println(string(b)) + fmt.Print(string(b)) } diff --git a/cli/server/status/command.go b/cli/server/status/command.go new file mode 100644 index 0000000..834a606 --- /dev/null +++ b/cli/server/status/command.go @@ -0,0 +1,52 @@ +package status + +import ( + "fmt" + "net/url" + "os" + + "github.com/andydunstall/piko/server/status/client" + "github.com/andydunstall/piko/server/status/config" + "github.com/spf13/cobra" +) + +func NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "status", + Short: "inspect server status", + Long: `Inspect server status. + +Each Piko server exposes a status API to inspect the state of the node, this +can be used to answer questions such as: +* What upstream listeners are attached to each node? +* What cluster state does this node know? +* What is the gossip state of each known node? + +See 'status --help' for the availale commands. + +Examples: + # Inspect the known nodes in the cluster. 
+ piko server status cluster nodes +`, + } + + var conf config.Config + conf.RegisterFlags(cmd.PersistentFlags()) + + c := client.NewClient(nil) + + cmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { + if err := conf.Validate(); err != nil { + fmt.Printf("config: %s\n", err.Error()) + os.Exit(1) + } + + url, _ := url.Parse(conf.Server.URL) + c.SetURL(url) + } + + cmd.AddCommand(newClusterCommand(c)) + cmd.AddCommand(newGossipCommand(c)) + + return cmd +} diff --git a/cli/status/gossip.go b/cli/server/status/gossip.go similarity index 51% rename from cli/status/gossip.go rename to cli/server/status/gossip.go index 9bf6e1f..8af8e95 100644 --- a/cli/status/gossip.go +++ b/cli/server/status/gossip.go @@ -2,30 +2,28 @@ package status import ( "fmt" - "net/url" "os" "sort" "github.com/andydunstall/piko/pkg/gossip" - "github.com/andydunstall/piko/status/client" - "github.com/andydunstall/piko/status/config" + "github.com/andydunstall/piko/server/status/client" yaml "github.com/goccy/go-yaml" "github.com/spf13/cobra" ) -func newGossipCommand() *cobra.Command { +func newGossipCommand(c *client.Client) *cobra.Command { cmd := &cobra.Command{ Use: "gossip", Short: "inspect gossip state", } - cmd.AddCommand(newGossipNodesCommand()) - cmd.AddCommand(newGossipNodeCommand()) + cmd.AddCommand(newGossipNodesCommand(c)) + cmd.AddCommand(newGossipNodeCommand(c)) return cmd } -func newGossipNodesCommand() *cobra.Command { +func newGossipNodesCommand(c *client.Client) *cobra.Command { cmd := &cobra.Command{ Use: "nodes", Short: "inspect gossip nodes", @@ -35,20 +33,12 @@ Queries the server for the metadata for each known gossip node in the cluster. Examples: - piko status gossip nodes + piko server status gossip nodes `, } - var conf config.Config - conf.RegisterFlags(cmd.Flags()) - cmd.Run = func(cmd *cobra.Command, args []string) { - if err := conf.Validate(); err != nil { - fmt.Printf("invalid config: %s\n", err.Error()) - os.Exit(1) - } - - showGossipNodes(&conf) + showGossipNodes(c) } return cmd @@ -58,13 +48,10 @@ type gossipNodesOutput struct { Nodes []gossip.NodeMetadata `json:"nodes"` } -func showGossipNodes(conf *config.Config) { - // The URL has already been validated in conf. - url, _ := url.Parse(conf.Server.URL) - client := client.NewClient(url, conf.Forward) - defer client.Close() +func showGossipNodes(c *client.Client) { + gossip := client.NewGossip(c) - nodes, err := client.GossipNodes() + nodes, err := gossip.Nodes() if err != nil { fmt.Printf("failed to get gossip nodes: %s\n", err.Error()) os.Exit(1) @@ -82,7 +69,7 @@ func showGossipNodes(conf *config.Config) { fmt.Println(string(b)) } -func newGossipNodeCommand() *cobra.Command { +func newGossipNodeCommand(c *client.Client) *cobra.Command { cmd := &cobra.Command{ Use: "node", Args: cobra.ExactArgs(1), @@ -92,32 +79,21 @@ func newGossipNodeCommand() *cobra.Command { Queries the server for the known state of the gossip node with the given ID. Examples: - piko status gossip node bbc69214 + piko server status gossip node bbc69214 `, } - var conf config.Config - conf.RegisterFlags(cmd.Flags()) - cmd.Run = func(cmd *cobra.Command, args []string) { - if err := conf.Validate(); err != nil { - fmt.Printf("invalid config: %s\n", err.Error()) - os.Exit(1) - } - - showGossipNode(args[0], &conf) + showGossipNode(args[0], c) } return cmd } -func showGossipNode(nodeID string, conf *config.Config) { - // The URL has already been validated in conf. 
- url, _ := url.Parse(conf.Server.URL) - client := client.NewClient(url, conf.Forward) - defer client.Close() +func showGossipNode(nodeID string, c *client.Client) { + gossip := client.NewGossip(c) - node, err := client.GossipNode(nodeID) + node, err := gossip.Node(nodeID) if err != nil { fmt.Printf("failed to get gossip node: %s: %s\n", nodeID, err.Error()) os.Exit(1) diff --git a/cli/serverv2/command.go b/cli/serverv2/command.go deleted file mode 100644 index 97e789c..0000000 --- a/cli/serverv2/command.go +++ /dev/null @@ -1,134 +0,0 @@ -package server - -import ( - "context" - "fmt" - "net" - "os" - "os/signal" - "syscall" - "time" - - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/serverv2/reverseproxy" - "github.com/andydunstall/piko/serverv2/upstream" - rungroup "github.com/oklog/run" - "github.com/spf13/cobra" - "go.uber.org/zap" -) - -func NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "serverv2 [flags]", - Short: "start a server node", - Long: `Start a server node. - -WARNING: Server V2 is still in development... -`, - // TODO(andydunstall): Hide while in development. - Hidden: true, - } - - cmd.Run = func(cmd *cobra.Command, args []string) { - logger, err := log.NewLogger("debug", nil) - if err != nil { - fmt.Printf("failed to setup logger: %s\n", err.Error()) - os.Exit(1) - } - - if err := run(logger); err != nil { - logger.Error("failed to run server", zap.Error(err)) - os.Exit(1) - } - } - - return cmd -} - -func run(logger log.Logger) error { - logger.Info("starting piko server") - logger.Warn("piko server v2 is still in development") - - proxyLn, err := net.Listen("tcp", ":8000") - if err != nil { - return fmt.Errorf("proxy listen: %s: %w", ":8001", err) - } - - upstreamLn, err := net.Listen("tcp", ":8001") - if err != nil { - return fmt.Errorf("upstream listen: %s: %w", ":8001", err) - } - - upstreamManager := upstream.NewManager() - proxyServer := reverseproxy.NewServer(upstreamManager, logger) - upstreamServer := upstream.NewServer(upstreamManager, nil, logger) - - var group rungroup.Group - - // Proxy server. - group.Add(func() error { - if err := proxyServer.Serve(proxyLn); err != nil { - return fmt.Errorf("proxy server serve: %w", err) - } - return nil - }, func(error) { - shutdownCtx, cancel := context.WithTimeout( - context.Background(), - time.Second*10, - ) - defer cancel() - - if err := proxyServer.Shutdown(shutdownCtx); err != nil { - logger.Warn("failed to gracefully shutdown proxy server", zap.Error(err)) - } - - logger.Info("proxy server shut down") - }) - - // Upstream server. - group.Add(func() error { - if err := upstreamServer.Serve(upstreamLn); err != nil { - return fmt.Errorf("upstream server serve: %w", err) - } - return nil - }, func(error) { - shutdownCtx, cancel := context.WithTimeout( - context.Background(), - time.Second*10, - ) - defer cancel() - - if err := upstreamServer.Shutdown(shutdownCtx); err != nil { - logger.Warn("failed to gracefully shutdown upstream server", zap.Error(err)) - } - - logger.Info("upstream server shut down") - }) - - // Termination handler. 
- signalCtx, signalCancel := context.WithCancel(context.Background()) - signalCh := make(chan os.Signal, 1) - signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM) - group.Add(func() error { - select { - case sig := <-signalCh: - logger.Info( - "received shutdown signal", - zap.String("signal", sig.String()), - ) - return nil - case <-signalCtx.Done(): - return nil - } - }, func(error) { - signalCancel() - }) - - if err := group.Run(); err != nil { - return err - } - - logger.Info("shutdown complete") - - return nil -} diff --git a/cli/status/command.go b/cli/status/command.go deleted file mode 100644 index d8bffd9..0000000 --- a/cli/status/command.go +++ /dev/null @@ -1,36 +0,0 @@ -package status - -import "github.com/spf13/cobra" - -func NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "status", - Short: "inspect server status", - Long: `Inspect server status. - -Each Piko server exposes a status API to inspect the state of the node, this -can be used to answer questions such as: -* What upstream listeners are attached to each node? -* What cluster state does this node know? -* What is the gossip state of each known node? - -See 'status --help' for the availale commands. - -Examples: - # Inspect the known nodes in the cluster. - piko status cluster nodes - - # Inspect the upstream listeners connected to this node. - piko status proxy endpoints - - # Inspect the status of server 10.26.104.56:8002. - piko status proxy endpoints --server 10.26.104.56:8002 -`, - } - - cmd.AddCommand(newProxyCommand()) - cmd.AddCommand(newClusterCommand()) - cmd.AddCommand(newGossipCommand()) - - return cmd -} diff --git a/cli/status/proxy.go b/cli/status/proxy.go deleted file mode 100644 index 50cd347..0000000 --- a/cli/status/proxy.go +++ /dev/null @@ -1,76 +0,0 @@ -package status - -import ( - "fmt" - "net/url" - "os" - - "github.com/andydunstall/piko/status/client" - "github.com/andydunstall/piko/status/config" - yaml "github.com/goccy/go-yaml" - "github.com/spf13/cobra" -) - -func newProxyCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "proxy", - Short: "inspect proxy status", - } - - cmd.AddCommand(newProxyEndpointsCommand()) - - return cmd -} - -func newProxyEndpointsCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "endpoints", - Short: "inspect proxy endpoints", - Long: `Inspect proxy endpoints. - -Queries the server for the set of endpoints with connected upstream listeners -connected to this node. The output contains the endpoint IDs and set of -listeners for that endpoint. - -Examples: - piko status proxy endpoints -`, - } - - var conf config.Config - conf.RegisterFlags(cmd.Flags()) - - cmd.Run = func(cmd *cobra.Command, args []string) { - if err := conf.Validate(); err != nil { - fmt.Printf("invalid config: %s\n", err.Error()) - os.Exit(1) - } - - showProxyEndpoints(&conf) - } - - return cmd -} - -type proxyEndpointsOutput struct { - Endpoints map[string][]string `json:"endpoints"` -} - -func showProxyEndpoints(conf *config.Config) { - // The URL has already been validated in conf. 
- url, _ := url.Parse(conf.Server.URL) - client := client.NewClient(url, conf.Forward) - defer client.Close() - - endpoints, err := client.ProxyEndpoints() - if err != nil { - fmt.Printf("failed to get proxy endpoints: %s\n", err.Error()) - os.Exit(1) - } - - output := proxyEndpointsOutput{ - Endpoints: endpoints, - } - b, _ := yaml.Marshal(output) - fmt.Println(string(b)) -} diff --git a/cli/workload/requests.go b/cli/workload/requests.go index 2da8895..6af15b7 100644 --- a/cli/workload/requests.go +++ b/cli/workload/requests.go @@ -12,7 +12,6 @@ import ( "syscall" "time" - pikoconfig "github.com/andydunstall/piko/pkg/config" "github.com/andydunstall/piko/pkg/log" "github.com/andydunstall/piko/workload/config" "github.com/spf13/cobra" @@ -52,41 +51,10 @@ Examples: var conf config.RequestsConfig - var configPath string - cmd.Flags().StringVar( - &configPath, - "config.path", - "", - ` -YAML config file path.`, - ) - - var configExpandEnv bool - cmd.Flags().BoolVar( - &configExpandEnv, - "config.expand-env", - false, - ` -Whether to expand environment variables in the config file. - -This will replaces references to ${VAR} or $VAR with the corresponding -environment variable. The replacement is case-sensitive. - -References to undefined variables will be replaced with an empty string. A -default value can be given using form ${VAR:default}.`, - ) - // Register flags and set default values. conf.RegisterFlags(cmd.Flags()) cmd.Run = func(cmd *cobra.Command, args []string) { - if configPath != "" { - if err := pikoconfig.Load(configPath, &conf, configExpandEnv); err != nil { - fmt.Printf("load config: %s\n", err.Error()) - os.Exit(1) - } - } - if err := conf.Validate(); err != nil { fmt.Printf("invalid config: %s\n", err.Error()) os.Exit(1) diff --git a/cli/workload/upstreams.go b/cli/workload/upstreams.go index 905d37e..1afe2b3 100644 --- a/cli/workload/upstreams.go +++ b/cli/workload/upstreams.go @@ -8,7 +8,6 @@ import ( "strconv" "syscall" - pikoconfig "github.com/andydunstall/piko/pkg/config" "github.com/andydunstall/piko/pkg/log" "github.com/andydunstall/piko/workload/config" "github.com/andydunstall/piko/workload/upstream" @@ -46,41 +45,10 @@ Examples: var conf config.UpstreamsConfig - var configPath string - cmd.Flags().StringVar( - &configPath, - "config.path", - "", - ` -YAML config file path.`, - ) - - var configExpandEnv bool - cmd.Flags().BoolVar( - &configExpandEnv, - "config.expand-env", - false, - ` -Whether to expand environment variables in the config file. - -This will replaces references to ${VAR} or $VAR with the corresponding -environment variable. The replacement is case-sensitive. - -References to undefined variables will be replaced with an empty string. A -default value can be given using form ${VAR:default}.`, - ) - // Register flags and set default values. 
conf.RegisterFlags(cmd.Flags()) cmd.Run = func(cmd *cobra.Command, args []string) { - if configPath != "" { - if err := pikoconfig.Load(configPath, &conf, configExpandEnv); err != nil { - fmt.Printf("load config: %s\n", err.Error()) - os.Exit(1) - } - } - if err := conf.Validate(); err != nil { fmt.Printf("invalid config: %s\n", err.Error()) os.Exit(1) diff --git a/docs/getting-started.md b/docs/getting-started.md index 673f954..b35649f 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -34,10 +34,10 @@ The following ports are exposed: and health checks You can verify Piko has started and discovered the other nodes in the cluster -by running `piko status cluster nodes`, which will request the set of known -nodes from the Piko admin API (routed to a random node), such as: +by running `piko server status cluster nodes`, which will request the set of +known nodes from the Piko admin API (routed to a random node), such as: ``` -$ piko status cluster nodes +$ piko server status cluster nodes nodes: - id: piko-1-fuvaflv status: active @@ -60,7 +60,7 @@ nodes: ``` You can also use the `--forward` flag to forward the request to a particular -node, such as `piko status cluster nodes --forward piko-3-p3wnt2z`. +node, such as `piko server status cluster nodes --forward piko-3-p3wnt2z`. The cluster also includes Prometheus and Grafana to inspect the cluster metrics. You can open Grafana at `http://localhost:3000`. @@ -77,7 +77,7 @@ First create a local HTTP server to forward requests to, such as Then run the Piko agent and register endpoint `my-endpoint` using: ```shell -piko agent my-endpoint/localhost:4000 +piko agent http my-endpoint 4000 ``` This will connect to the cluster load balancer, which routes the request to @@ -88,10 +88,10 @@ request to your service. See `piko agent -h` for the available options. You can verify the upstream has connected and registered the endpoint by -running `piko status cluster nodes` again, which will now show one of the nodes -has a connected stream and registered endpoint: +running `piko server status cluster nodes` again, which will now show one of +the nodes has a connected stream and registered endpoint: ``` -$ piko status cluster nodes +$ piko server status cluster nodes nodes: - id: piko-1-fuvaflv status: active @@ -103,10 +103,10 @@ nodes: ``` You can also inspect the upstreams connected for that registered endpoint with -`piko status proxy endpoints --forward `. Such as in the above example -the upstream is connected to node `piko-1-fuvaflv`: +`piko server status proxy endpoints --forward `. Such as in the above +example the upstream is connected to node `piko-1-fuvaflv`: ``` -$ piko status proxy endpoints --forward piko-1-fuvaflv +$ piko server status proxy endpoints --forward piko-1-fuvaflv endpoints: my-endpoint: - 172.18.0.7:39084 diff --git a/docs/manage/configure.md b/docs/manage/configure.md index 0d5f4a6..461d142 100644 --- a/docs/manage/configure.md +++ b/docs/manage/configure.md @@ -23,176 +23,175 @@ string. You can also define a default value using form `${VAR:default}`. The Piko server node is run using `piko server`. It has the following configuration: ``` -proxy: - # The host/port to listen for incoming proxy HTTP requests. - # - # If the host is unspecified it defaults to all listeners, such as - # '--proxy.bind-addr :8000' will listen on '0.0.0.0:8000'. - bind_addr: :8000 - - # Proxy listen address to advertise to other nodes in the cluster. This is the - # address other nodes will used to forward proxy requests. 
- # - # Such as if the listen address is ':8000', the advertised address may be - # '10.26.104.45:8000' or 'node1.cluster:8000'. - # - # By default, if the bind address includes an IP to bind to that will be used. - # If the bind address does not include an IP (such as ':8000') the nodes - # private IP will be used, such as a bind address of ':8000' may have an - # advertise address of '10.26.104.14:8000'. - advertise_addr: "" - - # The timeout when sending proxied requests to upstream listeners for forwarding - # to other nodes in the cluster. - # - # If the upstream does not respond within the given timeout a - # '504 Gateway Timeout' is returned to the client. - gateway_timeout: 15s - - tls: - # Whether to enable TLS on the listener. - # - # If enabled must configure the cert and key. - enabled: false - - # Path to the PEM encoded certificate file. - cert: "" - - # Path to the PEM encoded key file. - key: "" - - http: - # The maximum duration for reading the entire request, including the - # body. A zero or negative value means there will be no timeout. - read_timeout: 10s - - # The maximum duration for reading the request headers. If zero, - # http.read-timeout is used. - read_header_timeout: 10s - - # The maximum duration before timing out writes of the response.`, - write_timeout: 10s - - # The maximum amount of time to wait for the next request when - # keep-alives are enabled. - idle_timeout: 5m0s +cluster: + # A unique identifier for the node in the cluster. + # + # By default a random ID will be generated for the node. + node_id: "" + + # A prefix for the node ID. + # + # Piko will generate a unique random identifier for the node and append it to + # the given prefix. + # + # Such as you could use the node or pod name as a prefix, then add a unique + # identifier to ensure the node ID is unique across restarts. + node_id_prefix: "" + + # A list of addresses of members in the cluster to join. + # + # This may be either addresses of specific nodes, such as + # '--cluster.join 10.26.104.14,10.26.104.75', or a domain that resolves to + # the addresses of the nodes in the cluster (e.g. a Kubernetes headless + # service), such as '--cluster.join piko.prod-piko-ns'. + # + # Each address must include the host, and may optionally include a port. If no + # port is given, the gossip port of this node is used. + # + # Note each node propagates membership information to the other known nodes, + # so the initial set of configured members only needs to be a subset of nodes. + join: [] + + # Whether the server node should abort if it is configured with more than one + # node to join (excluding itself) but fails to join any members. + abort_if_join_fails: true - # The maximum number of bytes the server will read parsing the request - # header's keys and values, including the request line. - max_header_bytes: 1048576 +proxy: + # The host/port to listen for incoming proxy connections. + # + # If the host is unspecified it defaults to all listeners, such as + # '--proxy.bind-addr :8000' will listen on '0.0.0.0:8000'. + bind_addr: ":8000" + + # Proxy to advertise to other nodes in the cluster. This is the + # address other nodes will used to forward proxy connections. + # + # Such as if the listen address is ':8000', the advertised address may be + # '10.26.104.45:8000' or 'node1.cluster:8000'. + # + # By default, if the bind address includes an IP to bind to that will be used. 
+ # If the bind address does not include an IP (such as ':8000') the nodes + # private IP will be used, such as a bind address of ':8000' may have an + # advertise address of '10.26.104.14:8000'. + advertise_addr: "" + + # Timeout when forwarding incoming requests to the upstream. + timeout: 30s + + # Whether to log all incoming connections and requests. + access_log: true + + http: + # The maximum duration for reading the entire request, including the body. A + # zero or negative value means there will be no timeout. + read_timeout: 10s + + # The maximum duration for reading the request headers. If zero, + # http.read-timeout is used. + read_header_timeout: 10s + + # The maximum duration before timing out writes of the response. + write_timeout: 10s + + # The maximum amount of time to wait for the next request when keep-alives are + # enabled. + idle_timeout: 5m0s + + # The maximum number of bytes the server will read parsing the request header's + # keys and values, including the request line. + max_header_bytes: 1048576 + + tls: + # Whether to enable TLS on the listener. + # + # If enabled must configure the cert and key. + enabled: false + + # Path to the PEM encoded certificate file. + cert: "" + + # Path to the PEM encoded key file. + key: "" upstream: - # The host/port to listen for connections from upstream listeners. - # - # If the host is unspecified it defaults to all listeners, such as - # '--upstream.bind-addr :8001' will listen on '0.0.0.0:8001'. - bind_addr: :8001 + # The host/port to listen for incoming upstream connections. + # + # If the host is unspecified it defaults to all listeners, such as + # '--upstream.bind-addr :8001' will listen on '0.0.0.0:8001'. + bind_addr: ":8001" - # Upstream listen address to advertise to other nodes in the cluster. + tls: + # Whether to enable TLS on the listener. # - # Such as if the listen address is ':8001', the advertised address may be - # '10.26.104.45:8001' or 'node1.cluster:8001'. - # - # By default, if the bind address includes an IP to bind to that will be used. - # If the bind address does not include an IP (such as ':8001') the nodes - # private IP will be used, such as a bind address of ':8001' may have an - # advertise address of '10.16.104.14:8001'. - advertise_addr: "" - - tls: - # Whether to enable TLS on the listener. - # - # If enabled must configure the cert and key. - enabled: false - - # Path to the PEM encoded certificate file. - cert: "" - - # Path to the PEM encoded key file. - key: "" + # If enabled must configure the cert and key. + enabled: false -admin: - # The host/port to listen for incoming admin connections. - # - # If the host is unspecified it defaults to all listeners, such as - # '--admin.bind-addr :8002' will listen on '0.0.0.0:8002'. - bind_addr: :8002 + # Path to the PEM encoded certificate file. + cert: "" - # Admin listen address to advertise to other nodes in the cluster. This is the - # address other nodes will used to forward admin requests. - # - # Such as if the listen address is ':8002', the advertised address may be - # '10.26.104.45:8002' or 'node1.cluster:8002'. - # - # By default, if the bind address includes an IP to bind to that will be used. - # If the bind address does not include an IP (such as ':8002') the nodes - # private IP will be used, such as a bind address of ':8002' may have an - # advertise address of '10.26.104.14:8002'. - advertise_addr: "" - - tls: - # Whether to enable TLS on the listener. - # - # If enabled must configure the cert and key. 
- enabled: false - - # Path to the PEM encoded certificate file. - cert: "" - - # Path to the PEM encoded key file. - key: "" + # Path to the PEM encoded key file. + key: "" gossip: - # The host/port to listen for inter-node gossip traffic. - # - # If the host is unspecified it defaults to all listeners, such as - # '--gossip.bind-addr :8003' will listen on '0.0.0.0:8003'. - bind_addr: :8003 - - # Gossip listen address to advertise to other nodes in the cluster. This is the - # address other nodes will used to gossip with the node. - # - # Such as if the listen address is ':8003', the advertised address may be - # '10.26.104.45:8003' or 'node1.cluster:8003'. - # - # By default, if the bind address includes an IP to bind to that will be used. - # If the bind address does not include an IP (such as ':8003') the nodes - # private IP will be used, such as a bind address of ':8003' may have an - # advertise address of '10.26.104.14:8003'. - advertise_addr: "" + # The host/port to listen for inter-node gossip traffic. + # + # If the host is unspecified it defaults to all listeners, such as + # '--gossip.bind-addr :8003' will listen on '0.0.0.0:8003'. + bind_addr: ":8003" + + # Gossip listen address to advertise to other nodes in the cluster. This is the + # address other nodes will used to gossip with the node. + # + # Such as if the listen address is ':8003', the advertised address may be + # '10.26.104.45:8003' or 'node1.cluster:8003'. + # + # By default, if the bind address includes an IP to bind to that will be used. + # If the bind address does not include an IP (such as ':8003') the nodes + # private IP will be used, such as a bind address of ':8003' may have an + # advertise address of '10.26.104.14:8003'. + advertise_addr: "" + + # The interval to initiate rounds of gossip. + # + # Each gossip round selects another known node to synchronize with.`, + interval: 500ms + + # The maximum size of any packet sent. + # + # Depending on your networks MTU you may be able to increase to include more data + # in each packet. + max_packet_size: 1400 -cluster: - # A unique identifier for the node in the cluster. - # - # By default a random ID will be generated for the node. - node_id: "" - - # A prefix for the node ID. - # - # Piko will generate a unique random identifier for the node and append it to - # the given prefix. - # - # Such as you could use the node or pod name as a prefix, then add a unique - # identifier to ensure the node ID is unique across restarts. - node_id_prefix: "" - - # A list of addresses of members in the cluster to join. - # - # This may be either addresses of specific nodes, such as - # '--cluster.join 10.26.104.14,10.26.104.75', or a domain that resolves to - # the addresses of the nodes in the cluster (e.g. a Kubernetes headless - # service), such as '--cluster.join piko.prod-piko-ns'. - # - # Each address must include the host, and may optionally include a port. If no - # port is given, the gossip port of this node is used. - # - # Note each node propagates membership information to the other known nodes, - # so the initial set of configured members only needs to be a subset of nodes. - join: [] - - # Whether the server node should abort if it is configured with more than one - # node to join (excluding itself) but fails to join any members. - abort_if_join_fails: true +admin: + # The host/port to listen for incoming admin connections. + # + # If the host is unspecified it defaults to all listeners, such as + # '--admin.bind-addr :8002' will listen on '0.0.0.0:8002'. 
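+  # The admin listener also serves the Prometheus metrics, health check and
+  # status API used by 'piko server status'.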
+  bind_addr: ":8002"
+
+  # Admin listen address to advertise to other nodes in the cluster. This is the
+  # address other nodes will use to forward admin requests.
+  #
+  # Such as if the listen address is ':8002', the advertised address may be
+  # '10.26.104.45:8002' or 'node1.cluster:8002'.
+  #
+  # By default, if the bind address includes an IP to bind to that will be used.
+  # If the bind address does not include an IP (such as ':8002') the nodes
+  # private IP will be used, such as a bind address of ':8002' may have an
+  # advertise address of '10.26.104.14:8002'.
+  advertise_addr: ""
+
+  tls:
+    # Whether to enable TLS on the listener.
+    #
+    # If enabled must configure the cert and key.
+    enabled: false
+
+    # Path to the PEM encoded certificate file.
+    cert: ""
+
+    # Path to the PEM encoded key file.
+    key: ""
 
 auth:
   # Secret key to authenticate HMAC endpoint connection JWTs.
@@ -234,8 +233,8 @@ log:
 # Maximum duration after a shutdown signal is received (SIGTERM or
 # SIGINT) to gracefully shutdown the server node before terminating.
 # This includes handling in-progress HTTP requests, gracefully closing
-# connections to upstream listeners, announcing to the cluster the node is
-# leaving...
+# connections to upstream listeners and announcing to the cluster the node is
+# leaving.
 grace_period: 1m0s
 ```
 
@@ -296,67 +295,41 @@ forwarded by Piko.
 
 The Piko agent is run using `piko agent`. It has the following configuration:
 
 ```
-# The endpoints to register with the Piko server.
-#
-# Each endpoint has an ID and forwarding address.
-#
-# At least one endpoint must be configured.
-endpoints:
-  - id: endpoint-1
-    addr: localhost:4000
-  - id: endpoint-2
-    addr: localhost:5000
+# Listeners contains the set of listeners to register. Each listener has an
+# endpoint ID, address to forward connections to, whether to log each request
+# and a timeout to forward requests to the upstream.
+listeners:
+  - endpoint_id: my-endpoint
+    addr: localhost:3000
+    access_log: true
+    timeout: 15s
+
+connect:
+  # The Piko server URL to connect to. Note this must be configured to use the
+  # Piko server 'upstream' port.
+  url: http://localhost:8001
+
+  # A token to authenticate with the Piko server.
+  token: ""
+
+  # Timeout attempting to connect to the Piko server on boot. Note if the agent
+  # is disconnected after the initial connection succeeds it will keep trying to
+  # reconnect.
+  timeout: 30s
+
+  tls:
+    # A path to a certificate PEM file containing root certificate authorities to
+    # validate the TLS connection to the Piko server.
+    #
+    # Defaults to using the host root CAs.
+    root_cas: ""
 
 server:
-    # Piko server URL.
-    #
-    # The listener will add path /piko/v1/listener/:endpoint_id to the given URL,
-    # so if you include a path it will be used as a prefix.
-    #
-    # Note Piko connects to the server with WebSockets, so will replace http/https
-    # with ws/wss (you can configure either).
-    url: http://localhost:8001
+    # The host/port to bind the server to.
 
-    # Heartbeat interval.
-    #
-    # To verify the connection to the server is ok, the listener sends a
-    # heartbeat to the upstream at the '--server.heartbeat-interval'
-    # interval, with a timeout of '--server.heartbeat-timeout'.`,
-    heartbeat_interval: 10s
-
-    # Heartbeat timeout.
-    #
-    # To verify the connection to the server is ok, the listener sends a
-    # heartbeat to the upstream at the '--server.heartbeat-interval'
-    heartbeat_timeout: 10s
-
-auth:
-    # An API key to authenticate the connection to Piko.
-    api_key: ""
-
-forwarder:
-    # Forwarder timeout.
-    #
-    # This is the timeout between a listener receiving a request from Piko then
-    # forwarding it to the configured forward address, and receiving a response.
-    #
-    # If the upstream does not respond within the given timeout a
-    # '504 Gateway Timeout' is returned to the client.
-    timeout: 10s
-
-tls:
-    # A path to a certificate PEM file containing root certificiate authorities
-    # to validate the TLS connection to the Piko server.
-    #
-    # Defaults to using the host root CAs.`,
-    root_cas: ""
-
-admin:
-    # The host/port to listen for incoming admin connections.
-    #
-    # If the host is unspecified it defaults to all listeners, such as
-    # '--admin.bind-addr :9000' will listen on '0.0.0.0:9000'
-    bind_addr: :9000
+    # If the host is unspecified it defaults to all listeners, such as
+    # '--server.bind-addr :5000' will listen on '0.0.0.0:5000'.
+    bind_addr: ":5000"
 
 log:
   # Minimum log level to output.
@@ -372,9 +345,13 @@ log:
   #
   # Such as you can enable 'gossip' logs with '--log.subsystems gossip'.
   subsystems: []
+
+# Maximum duration after a shutdown signal is received (SIGTERM or
+# SIGINT) to gracefully shutdown each listener.
+grace_period: 1m0s
 ```
 
 ### Authentication
 
-To authenticate the agent, include a JWT in `auth.api_key`. The supported JWT
+To authenticate the agent, include a JWT in `connect.token`. The supported JWT
 formats are described above in the server configuration.
diff --git a/docs/manage/observability.md b/docs/manage/observability.md
index 1a9334f..948c4d1 100644
--- a/docs/manage/observability.md
+++ b/docs/manage/observability.md
@@ -36,11 +36,11 @@ Piko also includes a number of Grafana dashboards at
 ## Status
 
 Piko includes a status CLI to inspect a Piko server. Servers register endpoints
-at `/status` on the admin port that `piko status` then queries.
+at `/status` on the admin port that `piko server status` then queries.
 
 Such as to view the endpoints registers on a server use
-`piko status proxy endpoints`. Or to inspect the set of known nodes in the
-cluster use `piko status cluster nodes`.
+`piko server status proxy endpoints`. Or to inspect the set of known nodes in the
+cluster use `piko server status cluster nodes`.
 
 Configure the server URL with `--server.url`. You can also forward the request
 to a particular node ID using `--forward` (which can be useful when all nodes
diff --git a/docs/manage/overview.md b/docs/manage/overview.md
index 1b7f459..393dc2d 100644
--- a/docs/manage/overview.md
+++ b/docs/manage/overview.md
@@ -56,7 +56,7 @@ See [Configure](./configure.md) for details.
 Each server node has an admin port (`8003` by default) which includes
 Prometheus metrics at `/metrics`, a health endpoint at `/health`, and a status
 API at `/status`. The status API exposes endpoints for inspecting the status of
-a server node, which is used by the `piko status` CLI.
+a server node, which is used by the `piko server status` CLI.
 
 See [Observability](./observability.md) for details.
 
@@ -73,7 +73,7 @@ service.
 
 Such as you may configure the agent to register the endpoint
 `my-endpoint` then forward requests to `localhost:4000` using
-`piko agent my-endpoint/localhost:4000`.
+`piko agent http my-endpoint 4000`.
 
 See [Configure](./configure.md) for details.
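Putting the agent options above together, a minimal configuration file might look like the sketch below; the endpoint ID, upstream port and file name (`agent.yaml`) are illustrative assumptions rather than values from a real deployment.

```
# agent.yaml (illustrative sketch).
listeners:
  # Register 'my-endpoint' and forward its requests to localhost:3000.
  - endpoint_id: my-endpoint
    addr: localhost:3000
    access_log: true
    timeout: 15s

connect:
  # Must point at the Piko server 'upstream' port.
  url: http://localhost:8001
```

A file like this is loaded via the `--config.path` flag (with optional `${VAR}` expansion via `--config.expand-env`) registered in `pkg/config/config.go`.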
diff --git a/monitoring/dashboards/proxy.json b/monitoring/dashboards/proxy.json index 8afca31..bddf71d 100644 --- a/monitoring/dashboards/proxy.json +++ b/monitoring/dashboards/proxy.json @@ -15,7 +15,7 @@ "type": "grafana", "id": "grafana", "name": "Grafana", - "version": "10.4.3" + "version": "11.0.0" }, { "type": "datasource", @@ -152,6 +152,7 @@ "showLegend": true }, "tooltip": { + "maxHeight": 600, "mode": "single", "sort": "none" } @@ -247,6 +248,7 @@ "showLegend": true }, "tooltip": { + "maxHeight": 600, "mode": "single", "sort": "none" } @@ -342,6 +344,7 @@ "showLegend": true }, "tooltip": { + "maxHeight": 600, "mode": "single", "sort": "none" } @@ -463,6 +466,7 @@ "showLegend": true }, "tooltip": { + "maxHeight": 600, "mode": "single", "sort": "none" } @@ -584,6 +588,7 @@ "showLegend": true }, "tooltip": { + "maxHeight": 600, "mode": "single", "sort": "none" } @@ -718,6 +723,7 @@ "showLegend": true }, "tooltip": { + "maxHeight": 600, "mode": "single", "sort": "none" } @@ -729,7 +735,7 @@ "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", - "expr": "rate(piko_proxy_forwarded_local_total[$__rate_interval])", + "expr": "rate(piko_upstreams_upstream_requests_total[$__rate_interval])", "hide": false, "instant": false, "legendFormat": "{{instance}}", @@ -737,7 +743,7 @@ "refId": "B" } ], - "title": "Forwarded Local By Instance", + "title": "Upstream Requests By Instance", "type": "timeseries" }, { @@ -814,6 +820,7 @@ "showLegend": true }, "tooltip": { + "maxHeight": 600, "mode": "single", "sort": "none" } @@ -825,7 +832,7 @@ "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", - "expr": "sum by (instance) (rate(piko_proxy_forwarded_remote_total[$__rate_interval]))", + "expr": "sum by (instance) (rate(piko_upstreams_remote_requests_total[$__rate_interval]))", "hide": false, "instant": false, "legendFormat": "{{instance}}", @@ -833,7 +840,7 @@ "refId": "B" } ], - "title": "Forwarded Remote By Source", + "title": "Remote Requests By Source", "type": "timeseries" }, { @@ -910,6 +917,7 @@ "showLegend": true }, "tooltip": { + "maxHeight": 600, "mode": "single", "sort": "none" } @@ -921,7 +929,7 @@ "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", - "expr": "sum by (node_id) (rate(piko_proxy_forwarded_remote_total[$__rate_interval]))", + "expr": "sum by (node_id) (rate(piko_upstreams_remote_requests_total[$__rate_interval]))", "hide": false, "instant": false, "legendFormat": "{{instance}}", @@ -929,7 +937,7 @@ "refId": "B" } ], - "title": "Forwarded Remote By Target", + "title": "Remote Requests By Target", "type": "timeseries" }, { @@ -1019,6 +1027,7 @@ "showLegend": true }, "tooltip": { + "maxHeight": 600, "mode": "single", "sort": "none" } @@ -1030,7 +1039,7 @@ "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", - "expr": "piko_proxy_connected_upstreams", + "expr": "piko_upstreams_connected_upstreams", "hide": false, "instant": false, "legendFormat": "{{instance}}", @@ -1115,6 +1124,7 @@ "showLegend": true }, "tooltip": { + "maxHeight": 600, "mode": "single", "sort": "none" } @@ -1126,7 +1136,7 @@ "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", - "expr": "piko_proxy_connected_upstreams", + "expr": "piko_upstreams_registered_endpoints", "hide": false, "instant": false, "legendFormat": "{{instance}}", @@ -1134,7 +1144,7 @@ "refId": "B" } ], - "title": "Connected Upstreams By Instance", + "title": "Registered Endpoints By Instance", "type": "timeseries" } ], @@ -1150,6 +1160,7 @@ "from": "now-1h", "to": "now" }, + "timeRangeUpdatedDuringEditOrView": false, "timepicker": {}, "timezone": 
"browser", "title": "Piko Proxy", diff --git a/pkg/config/config.go b/pkg/config/config.go index 1279d94..029e687 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -6,19 +6,51 @@ import ( "os" "strings" + "github.com/spf13/pflag" "gopkg.in/yaml.v3" ) +type Config struct { + Path string `json:"path" yaml:"path"` + ExpandEnv bool `json:"expand_env" yaml:"expand_env"` +} + +func (c *Config) RegisterFlags(fs *pflag.FlagSet) { + fs.StringVar( + &c.Path, + "config.path", + "", + ` +YAML config file path.`, + ) + + fs.BoolVar( + &c.ExpandEnv, + "config.expand-env", + false, + ` +Whether to expand environment variables in the config file. + +This will replaces references to ${VAR} or $VAR with the corresponding +environment variable. The replacement is case-sensitive. + +References to undefined variables will be replaced with an empty string. A +default value can be given using form ${VAR:default}.`, + ) +} + // Load load the YAML configuration from the file at the given path. -// -// This will expand environment VARiables if expand is true. -func Load(path string, conf interface{}, expand bool) error { - buf, err := os.ReadFile(path) +func (c *Config) Load(conf interface{}) error { + if c.Path == "" { + return nil + } + + buf, err := os.ReadFile(c.Path) if err != nil { - return fmt.Errorf("read file: %s: %w", path, err) + return fmt.Errorf("read file: %s: %w", c.Path, err) } - if expand { + if c.ExpandEnv { buf = []byte(expandEnv(string(buf))) } @@ -26,7 +58,7 @@ func Load(path string, conf interface{}, expand bool) error { dec.KnownFields(true) if err := dec.Decode(conf); err != nil { - return fmt.Errorf("parse config: %s: %w", path, err) + return fmt.Errorf("parse config: %s: %w", c.Path, err) } return nil diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 05e636d..30b2227 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -29,7 +29,12 @@ sub: assert.NoError(t, err) var conf fakeConfig - assert.NoError(t, Load(f.Name(), &conf, false)) + + loadConfig := &Config{ + Path: f.Name(), + ExpandEnv: false, + } + assert.NoError(t, loadConfig.Load(&conf)) assert.Equal(t, "val1", conf.Foo) assert.Equal(t, "val2", conf.Bar) @@ -50,7 +55,12 @@ sub: assert.NoError(t, err) var conf fakeConfig - assert.NoError(t, Load(f.Name(), &conf, true)) + + loadConfig := &Config{ + Path: f.Name(), + ExpandEnv: true, + } + assert.NoError(t, loadConfig.Load(&conf)) assert.Equal(t, "val1", conf.Foo) assert.Equal(t, "val2", conf.Bar) @@ -65,11 +75,20 @@ sub: assert.NoError(t, err) var conf fakeConfig - assert.Error(t, Load(f.Name(), &conf, false)) + + loadConfig := &Config{ + Path: f.Name(), + ExpandEnv: false, + } + assert.Error(t, loadConfig.Load(&conf)) }) t.Run("not found", func(t *testing.T) { var conf fakeConfig - assert.Error(t, Load("notfound", &conf, false)) + loadConfig := &Config{ + Path: "/a/b/c/notfound", + ExpandEnv: false, + } + assert.Error(t, loadConfig.Load(&conf)) }) } diff --git a/pkg/conn/conn.go b/pkg/conn/conn.go deleted file mode 100644 index 2b830b5..0000000 --- a/pkg/conn/conn.go +++ /dev/null @@ -1,31 +0,0 @@ -package conn - -import "io" - -// RetryableError indicates a error is retryable. 
-type RetryableError struct { - err error -} - -func NewRetryableError(err error) *RetryableError { - return &RetryableError{err} -} - -func (e *RetryableError) Unwrap() error { - return e.err -} - -func (e *RetryableError) Error() string { - return e.err.Error() -} - -// Conn represents a bi-directional message-oriented connection between -// two peers. -type Conn interface { - ReadMessage() ([]byte, error) - NextReader() (io.Reader, error) - WriteMessage(b []byte) error - NextWriter() (io.WriteCloser, error) - Addr() string - Close() error -} diff --git a/pkg/conn/websocket/websocket.go b/pkg/conn/websocket/websocket.go deleted file mode 100644 index fb005b6..0000000 --- a/pkg/conn/websocket/websocket.go +++ /dev/null @@ -1,138 +0,0 @@ -package websocket - -import ( - "context" - "crypto/tls" - "fmt" - "io" - "net/http" - "time" - - "github.com/andydunstall/piko/pkg/conn" - "github.com/gorilla/websocket" -) - -// retryableStatusCodes contains a set of HTTP status codes that should be -// retried. -var retryableStatusCodes = map[int]struct{}{ - http.StatusRequestTimeout: {}, - http.StatusTooManyRequests: {}, - http.StatusInternalServerError: {}, - http.StatusBadGateway: {}, - http.StatusServiceUnavailable: {}, - http.StatusGatewayTimeout: {}, -} - -type options struct { - token string - tlsConfig *tls.Config -} - -type Option interface { - apply(*options) -} - -type tokenOption string - -func (o tokenOption) apply(opts *options) { - opts.token = string(o) -} - -func WithToken(token string) Option { - return tokenOption(token) -} - -type tlsConfigOption struct { - TLSConfig *tls.Config -} - -func (o tlsConfigOption) apply(opts *options) { - opts.tlsConfig = o.TLSConfig -} - -func WithTLSConfig(config *tls.Config) Option { - return tlsConfigOption{TLSConfig: config} -} - -type Conn struct { - wsConn *websocket.Conn -} - -func NewConn(wsConn *websocket.Conn) *Conn { - return &Conn{ - wsConn: wsConn, - } -} - -func Dial(ctx context.Context, url string, opts ...Option) (*Conn, error) { - options := options{} - for _, o := range opts { - o.apply(&options) - } - - dialer := &websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: 45 * time.Second, - } - - header := make(http.Header) - if options.token != "" { - header.Set("Authorization", "Bearer "+options.token) - } - - if options.tlsConfig != nil { - dialer.TLSClientConfig = options.tlsConfig - } - - wsConn, resp, err := dialer.DialContext( - ctx, url, header, - ) - if err != nil { - if resp != nil { - if _, ok := retryableStatusCodes[resp.StatusCode]; ok { - return nil, conn.NewRetryableError(err) - } - return nil, fmt.Errorf("%d: %w", resp.StatusCode, err) - } - return nil, conn.NewRetryableError(err) - } - return NewConn(wsConn), nil -} - -func (c *Conn) ReadMessage() ([]byte, error) { - mt, message, err := c.wsConn.ReadMessage() - if err != nil { - return nil, err - } - if mt != websocket.BinaryMessage { - return nil, fmt.Errorf("unexpected websocket message type: %d", mt) - } - return message, nil -} - -func (c *Conn) NextReader() (io.Reader, error) { - mt, r, err := c.wsConn.NextReader() - if err != nil { - return nil, err - } - if mt != websocket.BinaryMessage { - return nil, fmt.Errorf("unexpected websocket message type: %d", mt) - } - return r, nil -} - -func (c *Conn) WriteMessage(b []byte) error { - return c.wsConn.WriteMessage(websocket.BinaryMessage, b) -} - -func (c *Conn) NextWriter() (io.WriteCloser, error) { - return c.wsConn.NextWriter(websocket.BinaryMessage) -} - -func (c *Conn) Addr() string { - 
return c.wsConn.RemoteAddr().String() -} - -func (c *Conn) Close() error { - return c.wsConn.Close() -} diff --git a/pkg/forwarder/forwarder.go b/pkg/forwarder/forwarder.go deleted file mode 100644 index 1262efb..0000000 --- a/pkg/forwarder/forwarder.go +++ /dev/null @@ -1,35 +0,0 @@ -package forwarder - -import ( - "context" - "net/http" -) - -// Forwarder handles forwarding the given HTTP request to a server. -type Forwarder interface { - Request(ctx context.Context, addr string, r *http.Request) (*http.Response, error) -} - -type forwarder struct { - client *http.Client -} - -func NewForwarder() Forwarder { - return &forwarder{ - client: &http.Client{}, - } -} - -func (f *forwarder) Request( - ctx context.Context, - addr string, - r *http.Request, -) (*http.Response, error) { - r = r.WithContext(ctx) - - r.URL.Scheme = "http" - r.URL.Host = addr - r.RequestURI = "" - - return f.client.Do(r) -} diff --git a/pkg/gossip/state.go b/pkg/gossip/state.go index 8cd25d9..d9735e1 100644 --- a/pkg/gossip/state.go +++ b/pkg/gossip/state.go @@ -53,7 +53,7 @@ type NodeMetadata struct { // Expiry contains the time the node state will expire. This is only set // if the node is considered left or unreachable until the expiry. - Expiry time.Time + Expiry time.Time `json:"expiry"` } // NodeState contains the known state for the node. diff --git a/agentv2/endpoint/logger.go b/pkg/middleware/logger.go similarity index 84% rename from agentv2/endpoint/logger.go rename to pkg/middleware/logger.go index d56f5c1..01276d9 100644 --- a/agentv2/endpoint/logger.go +++ b/pkg/middleware/logger.go @@ -1,4 +1,4 @@ -package endpoint +package middleware import ( "net/http" @@ -20,9 +20,9 @@ type loggedRequest struct { Duration string `json:"duration"` } -// NewLoggerMiddleware creates logging middleware that logs every request. -func NewLoggerMiddleware(accessLog bool, logger log.Logger) gin.HandlerFunc { - logger = logger.WithSubsystem("endpoint.access") +// NewLogger creates logging middleware that logs every request. +func NewLogger(accessLog bool, logger log.Logger) gin.HandlerFunc { + logger = logger.WithSubsystem(logger.Subsystem() + ".access") return func(c *gin.Context) { s := time.Now() diff --git a/server/server/middleware/metrics.go b/pkg/middleware/metrics.go similarity index 98% rename from server/server/middleware/metrics.go rename to pkg/middleware/metrics.go index 953e9e3..185d5c3 100644 --- a/server/server/middleware/metrics.go +++ b/pkg/middleware/metrics.go @@ -17,7 +17,6 @@ type Metrics struct { ResponseSize prometheus.Histogram } -// NewMetrics creates metrics middleware. func NewMetrics(subsystem string) *Metrics { sizeBuckets := prometheus.ExponentialBuckets(256, 4, 8) return &Metrics{ diff --git a/pkg/mux/session.go b/pkg/mux/session.go index 010c3c5..6e077fb 100644 --- a/pkg/mux/session.go +++ b/pkg/mux/session.go @@ -3,7 +3,6 @@ package mux import ( "io" "net" - "time" "golang.ngrok.com/muxado/v2" ) @@ -11,34 +10,30 @@ import ( // Session is a connection between two nodes that multiplexes multiple // connections on the underlying connection. // -// The session also has heartbeats to verify the underlying connection is -// healthy. -// // Session is a wrapper for the 'muxado' library. 
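//
// A rough usage sketch (the surrounding connection setup and the handleStream
// helper are assumptions, not part of this package): an upstream opens a
// client session over an established connection and accepts the multiplexed
// streams proxied to it:
//
//	sess := mux.OpenClient(conn) // conn is an established io.ReadWriteCloser
//	defer sess.Close()
//	for {
//		stream, err := sess.Accept()
//		if err != nil {
//			return err // session closed or failed
//		}
//		go handleStream(stream)
//	}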
type Session struct { - mux *muxado.Heartbeat + mux muxado.Session } func OpenClient(conn io.ReadWriteCloser) *Session { - sess := &Session{} - - mux := muxado.NewHeartbeat( - muxado.NewTypedStreamSession( - muxado.Client(conn, &muxado.Config{}), - ), - sess.onHeartbeat, - muxado.NewHeartbeatConfig(), - ) - mux.Start() + return &Session{ + mux: muxado.Client(conn, &muxado.Config{}), + } +} - sess.mux = mux +func OpenServer(conn io.ReadWriteCloser) *Session { + return &Session{ + mux: muxado.Server(conn, &muxado.Config{}), + } +} - return sess +func (s *Session) Dial() (net.Conn, error) { + return s.mux.OpenStream() } // Accept accepts a multiplexed connection. func (s *Session) Accept() (net.Conn, error) { - conn, err := s.mux.AcceptTypedStream() + conn, err := s.mux.AcceptStream() if err != nil { muxadoErr, _ := muxado.GetError(err) if muxadoErr == muxado.SessionClosed { @@ -49,12 +44,11 @@ func (s *Session) Accept() (net.Conn, error) { return conn, nil } -func (s *Session) Close() error { - return s.mux.Close() +func (s *Session) Wait() error { + err, _, _ := s.mux.Wait() + return err } -func (s *Session) onHeartbeat(_ time.Duration, timeout bool) { - if timeout { - s.mux.Close() - } +func (s *Session) Close() error { + return s.mux.Close() } diff --git a/pkg/rpc/handler.go b/pkg/rpc/handler.go deleted file mode 100644 index 78bf23f..0000000 --- a/pkg/rpc/handler.go +++ /dev/null @@ -1,35 +0,0 @@ -package rpc - -import "sync" - -// HandlerFunc handles the given request message and returns a response. -type HandlerFunc func(message []byte) []byte - -// Handler is responsible for registering RPC request handlers for RPC types. -type Handler struct { - handlers map[Type]HandlerFunc - mu sync.Mutex -} - -func NewHandler() *Handler { - return &Handler{ - handlers: make(map[Type]HandlerFunc), - } -} - -// Register adds a new handler for the given RPC request type. -func (h *Handler) Register(rpcType Type, handler HandlerFunc) { - h.mu.Lock() - defer h.mu.Unlock() - - h.handlers[rpcType] = handler -} - -// Find looks up the handler for the given RPC type. -func (h *Handler) Find(rpcType Type) (HandlerFunc, bool) { - h.mu.Lock() - defer h.mu.Unlock() - - handler, ok := h.handlers[rpcType] - return handler, ok -} diff --git a/pkg/rpc/header.go b/pkg/rpc/header.go deleted file mode 100644 index e51e304..0000000 --- a/pkg/rpc/header.go +++ /dev/null @@ -1,76 +0,0 @@ -package rpc - -import ( - "encoding/binary" - "fmt" -) - -const ( - headerSize = 12 - - bit1 flags = 1 << 15 - bit2 flags = 1 << 14 -) - -// flags is a bitset of message flags. -// -// From high order bit down, flags contains: -// - Request/Response: 0 if the message is a request, 1 if the message is a -// response -// - Not supported: 1 if the no handler for the requested RPC type is found -type flags uint16 - -// Response returns true if the message is a response, false if the message is -// a request. -func (f *flags) Response() bool { - return *f&bit1 != 0 -} - -func (f *flags) SetResponse() { - if f.Response() { - return - } - *f |= bit1 -} - -// ErrNotSupported returns true if the requested RPC type was not supported by -// the receiver. -func (f *flags) ErrNotSupported() bool { - return *f&bit2 != 0 -} - -func (f *flags) SetErrNotSupported() { - if f.ErrNotSupported() { - return - } - *f |= bit2 -} - -type header struct { - // RPCType contains the application RPC type, such as 'heartbeat'. - RPCType Type - - // ID uniquely identifies the request/response pair. - ID uint64 - - // Flags contains a bitset of flags. 
- Flags flags -} - -func (h *header) Encode() []byte { - b := make([]byte, headerSize) - binary.BigEndian.PutUint16(b, uint16(h.RPCType)) - binary.BigEndian.PutUint64(b[2:], h.ID) - binary.BigEndian.PutUint16(b[10:], uint16(h.Flags)) - return b -} - -func (h *header) Decode(b []byte) error { - if len(b) < headerSize { - return fmt.Errorf("message too small: %d", len(b)) - } - h.RPCType = Type(binary.BigEndian.Uint16(b)) - h.ID = binary.BigEndian.Uint64(b[2:]) - h.Flags = flags(binary.BigEndian.Uint16(b[10:])) - return nil -} diff --git a/pkg/rpc/header_test.go b/pkg/rpc/header_test.go deleted file mode 100644 index 098c8d9..0000000 --- a/pkg/rpc/header_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package rpc - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestFlags_Response(t *testing.T) { - var f flags - assert.False(t, f.Response()) - f.SetResponse() - assert.True(t, f.Response()) -} - -func TestFlags_ErrNotSupported(t *testing.T) { - var f flags - assert.False(t, f.ErrNotSupported()) - f.SetErrNotSupported() - assert.True(t, f.ErrNotSupported()) -} - -func TestHeader_Encode(t *testing.T) { - var flags flags - flags.SetResponse() - h := header{ - RPCType: Type(0xff), - ID: 0x012345678, - Flags: flags, - } - assert.Equal(t, []byte{0x0, 0xff, 0x0, 0x0, 0x0, 0x0, 0x12, 0x34, 0x56, 0x78, 0x80, 0x0}, h.Encode()) -} - -func TestHeader_Decode(t *testing.T) { - var flags flags - flags.SetResponse() - h1 := header{ - RPCType: Type(0xff), - ID: 0x012345678, - Flags: flags, - } - var h2 header - assert.NoError(t, h2.Decode(h1.Encode())) - assert.Equal(t, h1, h2) - - assert.Error(t, h2.Decode([]byte("xxx"))) -} diff --git a/pkg/rpc/stream.go b/pkg/rpc/stream.go deleted file mode 100644 index 69e0292..0000000 --- a/pkg/rpc/stream.go +++ /dev/null @@ -1,367 +0,0 @@ -package rpc - -import ( - "context" - "errors" - "fmt" - "sync" - "time" - - "github.com/andydunstall/piko/pkg/conn" - "github.com/andydunstall/piko/pkg/log" - "go.uber.org/atomic" - "go.uber.org/zap" -) - -var ( - ErrStreamClosed = errors.New("stream closed") -) - -type message struct { - Header *header - Payload []byte -} - -// Stream represents a bi-directional RPC stream between two peers. Either peer -// can send an RPC request to the other. -// -// The stream uses the underlying bi-directional connection to send RPC -// requests, and multiplexes multiple concurrent request/response RPCs on the -// same connection. -// -// Incoming RPC requests are handled in their own goroutine to avoid blocking -// the stream. -type Stream interface { - Addr() string - RPC(ctx context.Context, rpcType Type, req []byte) ([]byte, error) - Monitor( - ctx context.Context, - interval time.Duration, - timeout time.Duration, - ) error - Close() error -} - -type stream struct { - conn conn.Conn - handler *Handler - - // nextMessageID is the ID of the next RPC message to send. - nextMessageID *atomic.Uint64 - - writeCh chan *message - - // responseHandlers contains channels for RPC responses. - responseHandlers map[uint64]chan<- *message - responseHandlersMu sync.Mutex - - // shutdownCh is closed when the stream is shutdown. - shutdownCh chan struct{} - // shutdownErr is the first error that caused the stream to shutdown. - shutdownErr error - // shutdown indicates whether the stream is already shutdown. - shutdown *atomic.Bool - - logger log.Logger -} - -// NewStream creates an RPC stream on top of the given message-oriented -// connection. 
-func NewStream(conn conn.Conn, handler *Handler, logger log.Logger) Stream { - stream := &stream{ - conn: conn, - handler: handler, - nextMessageID: atomic.NewUint64(0), - writeCh: make(chan *message, 64), - responseHandlers: make(map[uint64]chan<- *message), - shutdownCh: make(chan struct{}), - shutdown: atomic.NewBool(false), - logger: logger.WithSubsystem("rpc"), - } - go stream.reader() - go stream.writer() - - return stream -} - -func (s *stream) Addr() string { - return s.conn.Addr() -} - -// RPC sends the given request message to the peer and returns the response or -// an error. -// -// RPC is thread safe. -func (s *stream) RPC(ctx context.Context, rpcType Type, req []byte) ([]byte, error) { - header := &header{ - RPCType: rpcType, - ID: s.nextMessageID.Inc(), - } - msg := &message{ - Header: header, - Payload: req, - } - - ch := make(chan *message, 1) - s.registerResponseHandler(header.ID, ch) - defer s.unregisterResponseHandler(header.ID) - - select { - case s.writeCh <- msg: - case <-s.shutdownCh: - return nil, s.shutdownErr - case <-ctx.Done(): - return nil, ctx.Err() - } - - select { - case resp := <-ch: - if resp.Header.Flags.ErrNotSupported() { - return nil, fmt.Errorf("not supported") - } - return resp.Payload, nil - case <-s.shutdownCh: - return nil, s.shutdownErr - case <-ctx.Done(): - return nil, ctx.Err() - } -} - -// Monitor monitors the stream is healthy using heartbeats. -func (s *stream) Monitor( - ctx context.Context, - interval time.Duration, - timeout time.Duration, -) error { - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - if err := s.heartbeat(ctx, timeout); err != nil { - return fmt.Errorf("heartbeat: %w", err) - - } - - select { - case <-ctx.Done(): - return ctx.Err() - case <-s.shutdownCh: - return s.shutdownErr - case <-ticker.C: - } - } -} - -func (s *stream) Close() error { - return s.closeStream(ErrStreamClosed) -} - -func (s *stream) reader() { - defer s.recoverPanic("reader()") - - for { - b, err := s.conn.ReadMessage() - if err != nil { - _ = s.closeStream(fmt.Errorf("read: %w", err)) - return - } - - var header header - if err = header.Decode(b); err != nil { - _ = s.closeStream(fmt.Errorf("decode header: %w", err)) - return - } - payload := b[headerSize:] - - s.logger.Debug( - "message received", - zap.String("type", header.RPCType.String()), - zap.Bool("response", header.Flags.Response()), - zap.Uint64("message_id", header.ID), - zap.Int("len", len(payload)), - ) - - if header.Flags.Response() { - s.handleResponse(&message{ - Header: &header, - Payload: payload, - }) - } else { - // Spawn a new goroutine for each request to avoid blocking - // the read loop. 
- go s.handleRequest(&message{ - Header: &header, - Payload: payload, - }) - } - - select { - case <-s.shutdownCh: - return - default: - } - } -} - -func (s *stream) writer() { - defer s.recoverPanic("writer()") - - for { - select { - case req := <-s.writeCh: - if err := s.write(req); err != nil { - _ = s.closeStream(fmt.Errorf("write: %w", err)) - return - } - - s.logger.Debug( - "message sent", - zap.String("type", req.Header.RPCType.String()), - zap.Bool("response", req.Header.Flags.Response()), - zap.Uint64("message_id", req.Header.ID), - zap.Int("len", len(req.Payload)), - ) - case <-s.shutdownCh: - return - } - } -} - -func (s *stream) write(req *message) error { - w, err := s.conn.NextWriter() - if err != nil { - return err - } - if _, err = w.Write(req.Header.Encode()); err != nil { - return err - } - if len(req.Payload) > 0 { - if _, err = w.Write(req.Payload); err != nil { - return err - } - } - return w.Close() -} - -func (s *stream) closeStream(err error) error { - // Only shutdown once. - if !s.shutdown.CompareAndSwap(false, true) { - return ErrStreamClosed - } - - s.shutdownErr = ErrStreamClosed - // Close to cancel pending RPC requests. - close(s.shutdownCh) - - if err := s.conn.Close(); err != nil { - return fmt.Errorf("close conn: %w", err) - } - - s.logger.Debug( - "stream closed", - zap.Error(err), - ) - - return nil -} - -func (s *stream) handleRequest(m *message) { - handlerFunc, ok := s.handler.Find(m.Header.RPCType) - if !ok { - // If no handler is found, send a 'not supported' error to the client. - s.logger.Warn( - "rpc type not supported", - zap.String("type", m.Header.RPCType.String()), - zap.Uint64("message_id", m.Header.ID), - ) - - var flags flags - flags.SetResponse() - flags.SetErrNotSupported() - msg := &message{ - Header: &header{ - RPCType: m.Header.RPCType, - ID: m.Header.ID, - Flags: flags, - }, - } - select { - case s.writeCh <- msg: - return - case <-s.shutdownCh: - return - } - } - - resp := handlerFunc(m.Payload) - - var flags flags - flags.SetResponse() - msg := &message{ - Header: &header{ - RPCType: m.Header.RPCType, - ID: m.Header.ID, - Flags: flags, - }, - Payload: resp, - } - - select { - case s.writeCh <- msg: - return - case <-s.shutdownCh: - return - } -} - -func (s *stream) handleResponse(m *message) { - // If no handler is found, it means RPC has already returned so discard - // the response. 
- ch, ok := s.findResponseHandler(m.Header.ID) - if ok { - ch <- m - } -} - -func (s *stream) recoverPanic(prefix string) { - if r := recover(); r != nil { - _ = s.closeStream(fmt.Errorf("panic: %s: %v", prefix, r)) - } -} - -func (s *stream) registerResponseHandler(id uint64, ch chan<- *message) { - s.responseHandlersMu.Lock() - defer s.responseHandlersMu.Unlock() - - s.responseHandlers[id] = ch -} - -func (s *stream) unregisterResponseHandler(id uint64) { - s.responseHandlersMu.Lock() - defer s.responseHandlersMu.Unlock() - - delete(s.responseHandlers, id) -} - -func (s *stream) findResponseHandler(id uint64) (chan<- *message, bool) { - s.responseHandlersMu.Lock() - defer s.responseHandlersMu.Unlock() - - ch, ok := s.responseHandlers[id] - return ch, ok -} - -func (s *stream) heartbeat(ctx context.Context, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - ts := time.Now() - _, err := s.RPC(ctx, TypeHeartbeat, nil) - if err != nil { - return fmt.Errorf("rpc: %w", err) - } - - s.logger.Debug("heartbeat ok", zap.Duration("rtt", time.Since(ts))) - - return nil -} diff --git a/pkg/rpc/type.go b/pkg/rpc/type.go deleted file mode 100644 index 76142b8..0000000 --- a/pkg/rpc/type.go +++ /dev/null @@ -1,23 +0,0 @@ -package rpc - -// Type is an identifier for the RPC request/response type. -type Type uint16 - -const ( - // TypeHeartbeat sends health checks between peers. - TypeHeartbeat Type = iota + 1 - // TypeProxyHTTP sends a HTTP request and response between the Piko server - // and an upstream listener. - TypeProxyHTTP -) - -func (t *Type) String() string { - switch *t { - case TypeHeartbeat: - return "heartbeat" - case TypeProxyHTTP: - return "proxy-http" - default: - return "unknown" - } -} diff --git a/server/admin/server.go b/server/admin/server.go new file mode 100644 index 0000000..d53cdd1 --- /dev/null +++ b/server/admin/server.go @@ -0,0 +1,122 @@ +package admin + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + + "github.com/andydunstall/piko/pkg/log" + "github.com/andydunstall/piko/server/status" + "github.com/gin-gonic/gin" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// Server is the admin HTTP server, which exposes endpoints for metrics, health +// and inspecting the node status. +type Server struct { + registry *prometheus.Registry + + httpServer *http.Server + + router *gin.Engine + + logger log.Logger +} + +func NewServer( + registry *prometheus.Registry, + tlsConfig *tls.Config, + logger log.Logger, +) *Server { + logger = logger.WithSubsystem("admin") + + router := gin.New() + server := &Server{ + registry: registry, + httpServer: &http.Server{ + Handler: router, + TLSConfig: tlsConfig, + ErrorLog: logger.StdLogger(zapcore.WarnLevel), + }, + router: router, + logger: logger, + } + + // Recover from panics. 
+ router.Use(gin.CustomRecoveryWithWriter(nil, server.panicRoute)) + + server.registerRoutes(router) + + return server +} + +func (s *Server) Serve(ln net.Listener) error { + s.logger.Info( + "starting admin server", + zap.String("addr", ln.Addr().String()), + ) + + var err error + if s.httpServer.TLSConfig != nil { + err = s.httpServer.ServeTLS(ln, "", "") + } else { + err = s.httpServer.Serve(ln) + } + + if err != nil && err != http.ErrServerClosed { + return fmt.Errorf("http serve: %w", err) + } + return nil +} + +// Shutdown attempts to gracefully shutdown the server by waiting for pending +// requests to complete. +func (s *Server) Shutdown(ctx context.Context) error { + return s.httpServer.Shutdown(ctx) +} + +func (s *Server) AddStatus(route string, handler status.Handler) { + group := s.router.Group("/status").Group(route) + handler.Register(group) +} + +func (s *Server) registerRoutes(router *gin.Engine) { + router.GET("/health", s.healthRoute) + + if s.registry != nil { + router.GET("/metrics", s.metricsHandler()) + } +} + +func (s *Server) healthRoute(c *gin.Context) { + c.Status(http.StatusOK) +} + +func (s *Server) panicRoute(c *gin.Context, err any) { + s.logger.Error( + "handler panic", + zap.String("path", c.FullPath()), + zap.Any("err", err), + ) + c.AbortWithStatus(http.StatusInternalServerError) +} + +func (s *Server) metricsHandler() gin.HandlerFunc { + h := promhttp.HandlerFor( + s.registry, + promhttp.HandlerOpts{Registry: s.registry}, + ) + return func(c *gin.Context) { + h.ServeHTTP(c.Writer, c.Request) + } +} + +func init() { + // Disable Gin debug logs. + gin.SetMode(gin.ReleaseMode) +} diff --git a/server/admin/server_test.go b/server/admin/server_test.go new file mode 100644 index 0000000..a6b2f88 --- /dev/null +++ b/server/admin/server_test.go @@ -0,0 +1,164 @@ +package admin + +import ( + "bytes" + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "testing" + + "github.com/andydunstall/piko/pkg/log" + "github.com/andydunstall/piko/pkg/testutil" + "github.com/andydunstall/piko/server/status" + "github.com/gin-gonic/gin" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type fakeStatus struct { +} + +func (s *fakeStatus) Register(group *gin.RouterGroup) { + group.GET("/foo", s.fooRoute) +} + +func (s *fakeStatus) fooRoute(c *gin.Context) { + c.String(http.StatusOK, "foo") +} + +var _ status.Handler = &fakeStatus{} + +func TestServer_AdminRoutes(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + s := NewServer( + prometheus.NewRegistry(), + nil, + log.NewNopLogger(), + ) + go func() { + require.NoError(t, s.Serve(ln)) + }() + defer s.Shutdown(context.TODO()) + + t.Run("metrics", func(t *testing.T) { + url := fmt.Sprintf("http://%s/metrics", ln.Addr().String()) + resp, err := http.Get(url) + assert.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("not found", func(t *testing.T) { + url := fmt.Sprintf("http://%s/foo", ln.Addr().String()) + resp, err := http.Get(url) + assert.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + }) +} + +func TestServer_StatusRoutes(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + s := NewServer( + prometheus.NewRegistry(), + nil, + log.NewNopLogger(), + ) + s.AddStatus("/mystatus", &fakeStatus{}) + + go func() { + require.NoError(t, 
s.Serve(ln)) + }() + defer s.Shutdown(context.TODO()) + + t.Run("status ok", func(t *testing.T) { + url := fmt.Sprintf("http://%s/status/mystatus/foo", ln.Addr().String()) + resp, err := http.Get(url) + assert.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + buf := new(bytes.Buffer) + //nolint + buf.ReadFrom(resp.Body) + assert.Equal(t, []byte("foo"), buf.Bytes()) + }) + + t.Run("not found", func(t *testing.T) { + url := fmt.Sprintf("http://%s/status/notfound", ln.Addr().String()) + resp, err := http.Get(url) + assert.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + }) +} + +func TestServer_TLS(t *testing.T) { + rootCAPool, cert, err := testutil.LocalTLSServerCert() + require.NoError(t, err) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + tlsConfig := &tls.Config{} + tlsConfig.Certificates = []tls.Certificate{cert} + + s := NewServer( + prometheus.NewRegistry(), + tlsConfig, + log.NewNopLogger(), + ) + go func() { + require.NoError(t, s.Serve(ln)) + }() + defer s.Shutdown(context.TODO()) + + t.Run("https ok", func(t *testing.T) { + tlsConfig = &tls.Config{ + RootCAs: rootCAPool, + } + transport := &http.Transport{ + TLSClientConfig: tlsConfig, + } + client := &http.Client{ + Transport: transport, + } + + req, _ := http.NewRequest( + http.MethodGet, + fmt.Sprintf("https://%s/health", ln.Addr().String()), + nil, + ) + resp, err := client.Do(req) + assert.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("https bad ca", func(t *testing.T) { + url := fmt.Sprintf("https://%s/health", ln.Addr().String()) + _, err := http.Get(url) + assert.ErrorContains(t, err, "certificate signed by unknown authority") + }) + + t.Run("http", func(t *testing.T) { + url := fmt.Sprintf("http://%s/health", ln.Addr().String()) + resp, err := http.Get(url) + assert.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + }) +} diff --git a/server/config/config.go b/server/config/config.go index 99ad23b..c7a4224 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -10,147 +10,292 @@ import ( "github.com/spf13/pflag" ) -type UpstreamConfig struct { - // BindAddr is the address to bind to listen for incoming HTTP connections. - BindAddr string `json:"bind_addr" yaml:"bind_addr"` +type ClusterConfig struct { + // NodeID is a unique identifier for this node in the cluster. + NodeID string `json:"node_id" yaml:"node_id"` - // AdvertiseAddr is the address to advertise to other nodes. - AdvertiseAddr string `json:"advertise_addr" yaml:"advertise_addr"` + // NodeIDPrefix is a node ID prefix, where Piko will generate the rest of + // the node ID to ensure uniqueness. + NodeIDPrefix string `json:"node_id_prefix" yaml:"node_id_prefix"` - TLS TLSConfig `json:"tls" yaml:"tls"` + // Join contians a list of addresses of members in the cluster to join. 
+ Join []string `json:"join" yaml:"join"` + + AbortIfJoinFails bool `json:"abort_if_join_fails" yaml:"abort_if_join_fails"` } -func (c *UpstreamConfig) Validate() error { - if c.BindAddr == "" { - return fmt.Errorf("missing bind addr") - } - if err := c.TLS.Validate(); err != nil { - return fmt.Errorf("tls: %w", err) +func (c *ClusterConfig) Validate() error { + if c.NodeID == "" { + return fmt.Errorf("missing node id") } + return nil } -type AdminConfig struct { +func (c *ClusterConfig) RegisterFlags(fs *pflag.FlagSet) { + fs.StringVar( + &c.NodeID, + "cluster.node-id", + "", + ` +A unique identifier for the node in the cluster. + +By default a random ID will be generated for the node.`, + ) + + fs.StringVar( + &c.NodeIDPrefix, + "cluster.node-id-prefix", + "", + ` +A prefix for the node ID. + +Piko will generate a unique random identifier for the node and append it to +the given prefix. + +Such as you could use the node or pod name as a prefix, then add a unique +identifier to ensure the node ID is unique across restarts.`, + ) + + fs.StringSliceVar( + &c.Join, + "cluster.join", + nil, + ` +A list of addresses of members in the cluster to join. + +This may be either addresses of specific nodes, such as +'--cluster.join 10.26.104.14,10.26.104.75', or a domain that resolves to +the addresses of the nodes in the cluster (e.g. a Kubernetes headless +service), such as '--cluster.join piko.prod-piko-ns'. + +Each address must include the host, and may optionally include a port. If no +port is given, the gossip port of this node is used. + +Note each node propagates membership information to the other known nodes, +so the initial set of configured members only needs to be a subset of nodes.`, + ) + + fs.BoolVar( + &c.AbortIfJoinFails, + "cluster.abort-if-join-fails", + true, + ` +Whether the server node should abort if it is configured with more than one +node to join (excluding itself) but fails to join any members.`, + ) +} + +// HTTPConfig contains generic configuration for the HTTP servers. +type HTTPConfig struct { + // ReadTimeout is the maximum duration for reading the entire + // request, including the body. A zero or negative value means + // there will be no timeout. + ReadTimeout time.Duration `json:"read_timeout" yaml:"read_timeout"` + + // ReadHeaderTimeout is the amount of time allowed to read + // request headers. + ReadHeaderTimeout time.Duration `json:"read_header_timeout" yaml:"read_header_timeout"` + + // WriteTimeout is the maximum duration before timing out + // writes of the response. + WriteTimeout time.Duration `json:"write_timeout" yaml:"write_timeout"` + + // IdleTimeout is the maximum amount of time to wait for the + // next request when keep-alives are enabled. + IdleTimeout time.Duration `json:"idle_timeout" yaml:"idle_timeout"` + + // MaxHeaderBytes controls the maximum number of bytes the + // server will read parsing the request header's keys and + // values, including the request line. + MaxHeaderBytes int `json:"max_header_bytes" yaml:"max_header_bytes"` +} + +func (c *HTTPConfig) RegisterFlags(fs *pflag.FlagSet, prefix string) { + if prefix == "" { + prefix = "http." + } else { + prefix = prefix + ".http." + } + + fs.DurationVar( + &c.ReadTimeout, + prefix+"read-timeout", + time.Second*10, + ` +The maximum duration for reading the entire request, including the body. 
A +zero or negative value means there will be no timeout.`, + ) + fs.DurationVar( + &c.ReadHeaderTimeout, + prefix+"read-header-timeout", + time.Second*10, + ` +The maximum duration for reading the request headers. If zero, +http.read-timeout is used.`, + ) + fs.DurationVar( + &c.WriteTimeout, + prefix+"write-timeout", + time.Second*10, + ` +The maximum duration before timing out writes of the response.`, + ) + fs.DurationVar( + &c.IdleTimeout, + prefix+"idle-timeout", + time.Minute*5, + ` +The maximum amount of time to wait for the next request when keep-alives are +enabled.`, + ) + fs.IntVar( + &c.MaxHeaderBytes, + prefix+"max-header-bytes", + 1<<20, + ` +The maximum number of bytes the server will read parsing the request header's +keys and values, including the request line.`, + ) +} + +type ProxyConfig struct { // BindAddr is the address to bind to listen for incoming HTTP connections. BindAddr string `json:"bind_addr" yaml:"bind_addr"` // AdvertiseAddr is the address to advertise to other nodes. AdvertiseAddr string `json:"advertise_addr" yaml:"advertise_addr"` + // Timeout is the timeout to forward incoming requests to the upstream. + Timeout time.Duration `json:"timeout" yaml:"timeout"` + + // AccessLog indicates whether to log all incoming connections and + // requests. + AccessLog bool `json:"access_log" yaml:"access_log"` + + HTTP HTTPConfig `json:"http" yaml:"http"` + TLS TLSConfig `json:"tls" yaml:"tls"` } -func (c *AdminConfig) Validate() error { +func (c *ProxyConfig) Validate() error { if c.BindAddr == "" { return fmt.Errorf("missing bind addr") } + if c.Timeout == 0 { + return fmt.Errorf("missing timeout") + } if err := c.TLS.Validate(); err != nil { return fmt.Errorf("tls: %w", err) } return nil } -type ClusterConfig struct { - // NodeID is a unique identifier for this node in the cluster. - NodeID string `json:"node_id" yaml:"node_id"` +func (c *ProxyConfig) RegisterFlags(fs *pflag.FlagSet) { + fs.StringVar( + &c.BindAddr, + "proxy.bind-addr", + ":8000", + ` +The host/port to listen for incoming proxy connections. - // NodeIDPrefix is a node ID prefix, where Piko will generate the rest of - // the node ID to ensure uniqueness. - NodeIDPrefix string `json:"node_id_prefix" yaml:"node_id_prefix"` +If the host is unspecified it defaults to all listeners, such as +'--proxy.bind-addr :8000' will listen on '0.0.0.0:8000'`, + ) - // Join contians a list of addresses of members in the cluster to join. - Join []string `json:"join" yaml:"join"` + fs.StringVar( + &c.AdvertiseAddr, + "proxy.advertise-addr", + "", + ` +Proxy to advertise to other nodes in the cluster. This is the +address other nodes will used to forward proxy connections. - AbortIfJoinFails bool `json:"abort_if_join_fails" yaml:"abort_if_join_fails"` -} +Such as if the listen address is ':8000', the advertised address may be +'10.26.104.45:8000' or 'node1.cluster:8000'. -func (c *ClusterConfig) Validate() error { - if c.NodeID != "" && c.NodeIDPrefix != "" { - return fmt.Errorf("cannot specify both node ID and node ID prefix") - } - return nil -} +By default, if the bind address includes an IP to bind to that will be used. 
+If the bind address does not include an IP (such as ':8000') the nodes +private IP will be used, such as a bind address of ':8000' may have an +advertise address of '10.26.104.14:8000'.`, + ) + + fs.DurationVar( + &c.Timeout, + "proxy.timeout", + time.Second*30, + ` +Timeout when forwarding incoming requests to the upstream.`, + ) -type UsageConfig struct { - // Disable indicates whether to disable anonymous usage collection. - Disable bool `json:"disable" yaml:"disable"` + fs.BoolVar( + &c.AccessLog, + "proxy.access-log", + true, + ` +Whether to log all incoming connections and requests.`, + ) + + c.HTTP.RegisterFlags(fs, "proxy") + + c.TLS.RegisterFlags(fs, "proxy") } -type Config struct { - Proxy ProxyConfig `json:"proxy" yaml:"proxy"` - Upstream UpstreamConfig `json:"upstream" yaml:"upstream"` - Admin AdminConfig `json:"admin" yaml:"admin"` - Gossip gossip.Config `json:"gossip" yaml:"gossip"` - Cluster ClusterConfig `json:"cluster" yaml:"cluster"` - Auth auth.Config `json:"auth" yaml:"auth"` - Usage UsageConfig `json:"usage" yaml:"usage"` - Log log.Config `json:"log" yaml:"log"` +type UpstreamConfig struct { + // BindAddr is the address to bind to listen for incoming HTTP connections. + BindAddr string `json:"bind_addr" yaml:"bind_addr"` - // GracePeriod is the duration to gracefully shutdown the server. During - // the grace period, listeners and idle connections are closed, then waits - // for active requests to complete and closes their connections. - GracePeriod time.Duration `json:"grace_period" yaml:"grace_period"` + TLS TLSConfig `json:"tls" yaml:"tls"` } -func (c *Config) Validate() error { - if err := c.Proxy.Validate(); err != nil { - return fmt.Errorf("proxy: %w", err) - } - if err := c.Upstream.Validate(); err != nil { - return fmt.Errorf("upstream: %w", err) - } - if err := c.Admin.Validate(); err != nil { - return fmt.Errorf("admin: %w", err) - } - if err := c.Gossip.Validate(); err != nil { - return fmt.Errorf("gossip: %w", err) - } - if err := c.Cluster.Validate(); err != nil { - return fmt.Errorf("cluster: %w", err) - } - if err := c.Log.Validate(); err != nil { - return fmt.Errorf("log: %w", err) +func (c *UpstreamConfig) Validate() error { + if c.BindAddr == "" { + return fmt.Errorf("missing bind addr") } - - if c.GracePeriod == 0 { - return fmt.Errorf("missing grace period") + if err := c.TLS.Validate(); err != nil { + return fmt.Errorf("tls: %w", err) } - return nil } -func (c *Config) RegisterFlags(fs *pflag.FlagSet) { - c.Proxy.RegisterFlags(fs) - +func (c *UpstreamConfig) RegisterFlags(fs *pflag.FlagSet) { fs.StringVar( - &c.Upstream.BindAddr, + &c.BindAddr, "upstream.bind-addr", ":8001", ` -The host/port to listen for connections from upstream listeners. +The host/port to listen for incoming upstream connections. If the host is unspecified it defaults to all listeners, such as '--upstream.bind-addr :8001' will listen on '0.0.0.0:8001'`, ) - fs.StringVar( - &c.Upstream.AdvertiseAddr, - "upstream.advertise-addr", - "", - ` -Upstream listen address to advertise to other nodes in the cluster. -Such as if the listen address is ':8001', the advertised address may be -'10.26.104.45:8001' or 'node1.cluster:8001'. + c.TLS.RegisterFlags(fs, "upstream") +} -By default, if the bind address includes an IP to bind to that will be used. 
-If the bind address does not include an IP (such as ':8001') the nodes -private IP will be used, such as a bind address of ':8001' may have an -advertise address of '10.16.104.14:8001'.`, - ) - c.Upstream.TLS.RegisterFlags(fs, "upstream") +type AdminConfig struct { + // BindAddr is the address to bind to listen for incoming HTTP connections. + BindAddr string `json:"bind_addr" yaml:"bind_addr"` + + // AdvertiseAddr is the address to advertise to other nodes. + AdvertiseAddr string `json:"advertise_addr" yaml:"advertise_addr"` + + TLS TLSConfig `json:"tls" yaml:"tls"` +} + +func (c *AdminConfig) Validate() error { + if c.BindAddr == "" { + return fmt.Errorf("missing bind addr") + } + if err := c.TLS.Validate(); err != nil { + return fmt.Errorf("tls: %w", err) + } + return nil +} +func (c *AdminConfig) RegisterFlags(fs *pflag.FlagSet) { fs.StringVar( - &c.Admin.BindAddr, + &c.BindAddr, "admin.bind-addr", ":8002", ` @@ -159,8 +304,9 @@ The host/port to listen for incoming admin connections. If the host is unspecified it defaults to all listeners, such as '--admin.bind-addr :8002' will listen on '0.0.0.0:8002'`, ) + fs.StringVar( - &c.Admin.AdvertiseAddr, + &c.AdvertiseAddr, "admin.advertise-addr", "", ` @@ -175,72 +321,74 @@ If the bind address does not include an IP (such as ':8002') the nodes private IP will be used, such as a bind address of ':8002' may have an advertise address of '10.26.104.14:8002'.`, ) - c.Admin.TLS.RegisterFlags(fs, "admin") + c.TLS.RegisterFlags(fs, "admin") +} - fs.StringVar( - &c.Cluster.NodeID, - "cluster.node-id", - "", - ` -A unique identifier for the node in the cluster. +type Config struct { + Cluster ClusterConfig `json:"cluster" yaml:"cluster"` -By default a random ID will be generated for the node.`, - ) - fs.StringVar( - &c.Cluster.NodeIDPrefix, - "cluster.node-id-prefix", - "", - ` -A prefix for the node ID. + Proxy ProxyConfig `json:"proxy" yaml:"proxy"` -Piko will generate a unique random identifier for the node and append it to -the given prefix. + Upstream UpstreamConfig `json:"upstream" yaml:"upstream"` -Such as you could use the node or pod name as a prefix, then add a unique -identifier to ensure the node ID is unique across restarts.`, - ) - fs.StringSliceVar( - &c.Cluster.Join, - "cluster.join", - nil, - ` -A list of addresses of members in the cluster to join. + Gossip gossip.Config `json:"gossip" yaml:"gossip"` -This may be either addresses of specific nodes, such as -'--cluster.join 10.26.104.14,10.26.104.75', or a domain that resolves to -the addresses of the nodes in the cluster (e.g. a Kubernetes headless -service), such as '--cluster.join piko.prod-piko-ns'. + Admin AdminConfig `json:"admin" yaml:"admin"` -Each address must include the host, and may optionally include a port. If no -port is given, the gossip port of this node is used. + Auth auth.Config `json:"auth" yaml:"auth"` -Note each node propagates membership information to the other known nodes, -so the initial set of configured members only needs to be a subset of nodes.`, - ) - fs.BoolVar( - &c.Cluster.AbortIfJoinFails, - "cluster.abort-if-join-fails", - true, - ` -Whether the server node should abort if it is configured with more than one -node to join (excluding itself) but fails to join any members.`, - ) + Log log.Config `json:"log" yaml:"log"` - c.Auth.RegisterFlags(fs) + // GracePeriod is the duration to gracefully shutdown the server. 
During + // the grace period, listeners and idle connections are closed, then waits + // for active requests to complete and closes their connections. + GracePeriod time.Duration `json:"grace_period" yaml:"grace_period"` +} + +func (c *Config) Validate() error { + if err := c.Cluster.Validate(); err != nil { + return fmt.Errorf("cluster: %w", err) + } + + if err := c.Proxy.Validate(); err != nil { + return fmt.Errorf("proxy: %w", err) + } + + if err := c.Upstream.Validate(); err != nil { + return fmt.Errorf("upstream: %w", err) + } + + if err := c.Gossip.Validate(); err != nil { + return fmt.Errorf("gossip: %w", err) + } + + if err := c.Admin.Validate(); err != nil { + return fmt.Errorf("admin: %w", err) + } + + if err := c.Log.Validate(); err != nil { + return fmt.Errorf("log: %w", err) + } + + if c.GracePeriod == 0 { + return fmt.Errorf("missing grace period") + } + + return nil +} + +func (c *Config) RegisterFlags(fs *pflag.FlagSet) { + c.Cluster.RegisterFlags(fs) + + c.Proxy.RegisterFlags(fs) + + c.Upstream.RegisterFlags(fs) c.Gossip.RegisterFlags(fs) - fs.BoolVar( - &c.Usage.Disable, - "usage.disable", - false, - ` -Whether to disable anonymous usage tracking. + c.Admin.RegisterFlags(fs) -The Piko server periodically sends an anonymous report to help understand how -Piko is being used. This report includes the Piko version, host OS, host -architecture, requests processed and upstreams registered.`, - ) + c.Auth.RegisterFlags(fs) c.Log.RegisterFlags(fs) @@ -252,7 +400,7 @@ architecture, requests processed and upstreams registered.`, Maximum duration after a shutdown signal is received (SIGTERM or SIGINT) to gracefully shutdown the server node before terminating. This includes handling in-progress HTTP requests, gracefully closing -connections to upstream listeners, announcing to the cluster the node is -leaving...`, +connections to upstream listeners and announcing to the cluster the node is +leaving.`, ) } diff --git a/server/config/proxy.go b/server/config/proxy.go deleted file mode 100644 index 843da56..0000000 --- a/server/config/proxy.go +++ /dev/null @@ -1,154 +0,0 @@ -package config - -import ( - "fmt" - "time" - - "github.com/spf13/pflag" -) - -// ProxyHTTPConfig contains generic configuration for the HTTP servers. -type ProxyHTTPConfig struct { - // ReadTimeout is the maximum duration for reading the entire - // request, including the body. A zero or negative value means - // there will be no timeout. - ReadTimeout time.Duration `json:"read_timeout" yaml:"read_timeout"` - - // ReadHeaderTimeout is the amount of time allowed to read - // request headers. - ReadHeaderTimeout time.Duration `json:"read_header_timeout" yaml:"read_header_timeout"` - - // WriteTimeout is the maximum duration before timing out - // writes of the response. - WriteTimeout time.Duration `json:"write_timeout" yaml:"write_timeout"` - - // IdleTimeout is the maximum amount of time to wait for the - // next request when keep-alives are enabled. - IdleTimeout time.Duration `json:"idle_timeout" yaml:"idle_timeout"` - - // MaxHeaderBytes controls the maximum number of bytes the - // server will read parsing the request header's keys and - // values, including the request line. - MaxHeaderBytes int `json:"max_header_bytes" yaml:"max_header_bytes"` -} - -func (c *ProxyHTTPConfig) RegisterFlags(fs *pflag.FlagSet, prefix string) { - if prefix == "" { - prefix = "http." - } else { - prefix = prefix + ".http." 
- } - - fs.DurationVar( - &c.ReadTimeout, - prefix+"read-timeout", - time.Second*10, - ` -The maximum duration for reading the entire request, including the body. A -zero or negative value means there will be no timeout.`, - ) - fs.DurationVar( - &c.ReadHeaderTimeout, - prefix+"read-header-timeout", - time.Second*10, - ` -The maximum duration for reading the request headers. If zero, -http.read-timeout is used.`, - ) - fs.DurationVar( - &c.WriteTimeout, - prefix+"write-timeout", - time.Second*10, - ` -The maximum duration before timing out writes of the response.`, - ) - fs.DurationVar( - &c.IdleTimeout, - prefix+"idle-timeout", - time.Minute*5, - ` -The maximum amount of time to wait for the next request when keep-alives are -enabled.`, - ) - fs.IntVar( - &c.MaxHeaderBytes, - prefix+"max-header-bytes", - 1<<20, - ` -The maximum number of bytes the server will read parsing the request header's -keys and values, including the request line.`, - ) -} - -type ProxyConfig struct { - // BindAddr is the address to bind to listen for incoming HTTP connections. - BindAddr string `json:"bind_addr" yaml:"bind_addr"` - - // AdvertiseAddr is the address to advertise to other nodes. - AdvertiseAddr string `json:"advertise_addr" yaml:"advertise_addr"` - - // GatewayTimeout is the timeout in seconds of forwarding requests to an - // upstream listener. - GatewayTimeout time.Duration `json:"gateway_timeout" yaml:"gateway_timeout"` - - HTTP ProxyHTTPConfig `json:"http" yaml:"http"` - - TLS TLSConfig `json:"tls" yaml:"tls"` -} - -func (c *ProxyConfig) Validate() error { - if c.BindAddr == "" { - return fmt.Errorf("missing bind addr") - } - if c.GatewayTimeout == 0 { - return fmt.Errorf("missing gateway timeout") - } - if err := c.TLS.Validate(); err != nil { - return fmt.Errorf("tls: %w", err) - } - - return nil -} - -func (c *ProxyConfig) RegisterFlags(fs *pflag.FlagSet) { - fs.StringVar( - &c.BindAddr, - "proxy.bind-addr", - ":8000", - ` -The host/port to listen for incoming proxy HTTP requests. - -If the host is unspecified it defaults to all listeners, such as -'--proxy.bind-addr :8000' will listen on '0.0.0.0:8000'`, - ) - fs.StringVar( - &c.AdvertiseAddr, - "proxy.advertise-addr", - "", - ` -Proxy listen address to advertise to other nodes in the cluster. This is the -address other nodes will used to forward proxy requests. - -Such as if the listen address is ':8000', the advertised address may be -'10.26.104.45:8000' or 'node1.cluster:8000'. - -By default, if the bind address includes an IP to bind to that will be used. -If the bind address does not include an IP (such as ':8000') the nodes -private IP will be used, such as a bind address of ':8000' may have an -advertise address of '10.26.104.14:8000'.`, - ) - fs.DurationVar( - &c.GatewayTimeout, - "proxy.gateway-timeout", - time.Second*15, - ` -The timeout when sending proxied requests to upstream listeners for forwarding -to other nodes in the cluster. - -If the upstream does not respond within the given timeout a -'504 Gateway Timeout' is returned to the client.`, - ) - - c.TLS.RegisterFlags(fs, "proxy") - c.HTTP.RegisterFlags(fs, "proxy") -} diff --git a/server/proxy/conn.go b/server/proxy/conn.go deleted file mode 100644 index 7bc4316..0000000 --- a/server/proxy/conn.go +++ /dev/null @@ -1,68 +0,0 @@ -package proxy - -import ( - "bufio" - "bytes" - "context" - "fmt" - "net/http" - - "github.com/andydunstall/piko/pkg/rpc" -) - -// Conn represents a connection to an upstream endpoint. 
-type Conn interface { - EndpointID() string - Addr() string - Request(ctx context.Context, r *http.Request) (*http.Response, error) -} - -// RPCConn represents a connection to an upstream endpoint using -// rpc.Stream to exchange messages. -type RPCConn struct { - endpointID string - stream rpc.Stream -} - -func NewRPCConn(endpointID string, stream rpc.Stream) *RPCConn { - return &RPCConn{ - endpointID: endpointID, - stream: stream, - } -} - -func (c *RPCConn) EndpointID() string { - return c.endpointID -} - -func (c *RPCConn) Addr() string { - return c.stream.Addr() -} - -func (c *RPCConn) Request( - ctx context.Context, - r *http.Request, -) (*http.Response, error) { - // Encode the HTTP request. - var buffer bytes.Buffer - if err := r.Write(&buffer); err != nil { - return nil, fmt.Errorf("encode http request: %w", err) - } - - // Forward the request via RPC. - b, err := c.stream.RPC(ctx, rpc.TypeProxyHTTP, buffer.Bytes()) - if err != nil { - return nil, fmt.Errorf("rpc: %w", err) - } - - httpResp, err := http.ReadResponse( - bufio.NewReader(bytes.NewReader(b)), r, - ) - if err != nil { - return nil, fmt.Errorf("decode http response: %w", err) - } - - return httpResp, nil -} - -var _ Conn = &RPCConn{} diff --git a/server/proxy/conn_test.go b/server/proxy/conn_test.go deleted file mode 100644 index 3dd5f88..0000000 --- a/server/proxy/conn_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package proxy - -import ( - "bufio" - "bytes" - "context" - "io" - "net/http" - "net/url" - "testing" - "time" - - "github.com/andydunstall/piko/pkg/rpc" - "github.com/stretchr/testify/assert" -) - -type fakeStream struct { - rpcHandler func(rpcType rpc.Type, req []byte) ([]byte, error) -} - -func (s *fakeStream) Addr() string { - return "" -} - -func (s *fakeStream) RPC(_ context.Context, rpcType rpc.Type, req []byte) ([]byte, error) { - return s.rpcHandler(rpcType, req) -} - -func (s *fakeStream) Monitor( - _ context.Context, - _ time.Duration, - _ time.Duration, -) error { - return nil -} - -func (s *fakeStream) Close() error { - return nil -} - -func TestRPCStream(t *testing.T) { - rpcHandler := func(rpcType rpc.Type, req []byte) ([]byte, error) { - assert.Equal(t, rpc.TypeProxyHTTP, rpcType) - - httpReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(req))) - assert.NoError(t, err) - assert.Equal(t, "/foo", httpReq.URL.Path) - - header := make(http.Header) - header.Add("h1", "v1") - header.Add("h2", "v2") - header.Add("h3", "v3") - body := bytes.NewReader([]byte("foo")) - httpResp := &http.Response{ - StatusCode: http.StatusOK, - Header: header, - Body: io.NopCloser(body), - } - - var buffer bytes.Buffer - assert.NoError(t, httpResp.Write(&buffer)) - - return buffer.Bytes(), nil - } - stream := &fakeStream{rpcHandler: rpcHandler} - - conn := NewRPCConn("my-endpoint", stream) - - resp, err := conn.Request(context.TODO(), &http.Request{ - URL: &url.URL{ - Path: "/foo", - }, - Host: "my-endpoint.piko.com:8000", - }) - assert.NoError(t, err) - - assert.Equal(t, http.StatusOK, resp.StatusCode) - - assert.Equal(t, "v1", resp.Header.Get("h1")) - assert.Equal(t, "v2", resp.Header.Get("h2")) - assert.Equal(t, "v3", resp.Header.Get("h3")) - - buf := new(bytes.Buffer) - //nolint - buf.ReadFrom(resp.Body) - assert.Equal(t, []byte("foo"), buf.Bytes()) -} diff --git a/server/proxy/local.go b/server/proxy/local.go deleted file mode 100644 index b91a31b..0000000 --- a/server/proxy/local.go +++ /dev/null @@ -1,144 +0,0 @@ -package proxy - -import ( - "context" - "net/http" - "sync" - - 
"github.com/andydunstall/piko/pkg/log" -) - -// localEndpoint contains the local connections for an endpoint ID. -type localEndpoint struct { - conns []Conn - nextIndex int -} - -func (e *localEndpoint) AddConn(c Conn) { - e.conns = append(e.conns, c) -} - -// RemoveConn removes the connection if it exists and returns whether there are -// any remaining connections fro the endpoint ID. -func (e *localEndpoint) RemoveConn(c Conn) bool { - for i := 0; i != len(e.conns); i++ { - if e.conns[i] != c { - continue - } - e.conns = append(e.conns[:i], e.conns[i+1:]...) - if len(e.conns) == 0 { - return true - } - e.nextIndex %= len(e.conns) - return false - } - return len(e.conns) == 0 -} - -// Next returns the next connection to the endpoint in a round-robin fashion. -func (e *localEndpoint) Next() Conn { - if len(e.conns) == 0 { - return nil - } - - s := e.conns[e.nextIndex] - e.nextIndex++ - e.nextIndex %= len(e.conns) - return s -} - -// localProxy is responsible for forwarding requests to upstream endpoints -// connected to the local node. -type localProxy struct { - endpoints map[string]*localEndpoint - - mu sync.Mutex - - metrics *Metrics - - logger log.Logger -} - -func newLocalProxy(metrics *Metrics, logger log.Logger) *localProxy { - return &localProxy{ - endpoints: make(map[string]*localEndpoint), - metrics: metrics, - logger: logger, - } -} - -// Request attempts to forward the request to an upstream endpoint connected to -// the local node. -func (p *localProxy) Request( - ctx context.Context, - endpointID string, - r *http.Request, -) (*http.Response, error) { - conn := p.findConn(endpointID) - if conn == nil { - // No connection found. - return nil, errEndpointNotFound - } - - p.metrics.ForwardedLocalTotal.Inc() - - return conn.Request(ctx, r) -} - -func (p *localProxy) AddConn(conn Conn) { - p.mu.Lock() - defer p.mu.Unlock() - - e, ok := p.endpoints[conn.EndpointID()] - if !ok { - e = &localEndpoint{} - - p.metrics.RegisteredEndpoints.Inc() - } - - e.AddConn(conn) - p.endpoints[conn.EndpointID()] = e - - p.metrics.ConnectedUpstreams.Inc() -} - -func (p *localProxy) RemoveConn(conn Conn) { - p.mu.Lock() - defer p.mu.Unlock() - - endpoint, ok := p.endpoints[conn.EndpointID()] - if !ok { - return - } - if endpoint.RemoveConn(conn) { - delete(p.endpoints, conn.EndpointID()) - - p.metrics.RegisteredEndpoints.Dec() - } - - p.metrics.ConnectedUpstreams.Dec() -} - -func (p *localProxy) ConnAddrs() map[string][]string { - p.mu.Lock() - defer p.mu.Unlock() - - c := make(map[string][]string) - for endpointID, endpoint := range p.endpoints { - for _, conn := range endpoint.conns { - c[endpointID] = append(c[endpointID], conn.Addr()) - } - } - return c -} - -func (p *localProxy) findConn(endpointID string) Conn { - p.mu.Lock() - defer p.mu.Unlock() - - endpoint, ok := p.endpoints[endpointID] - if !ok { - return nil - } - return endpoint.Next() -} diff --git a/server/proxy/local_test.go b/server/proxy/local_test.go deleted file mode 100644 index 3e53d83..0000000 --- a/server/proxy/local_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package proxy - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestLocalEndpoint(t *testing.T) { - endpoint := &localEndpoint{} - - assert.Nil(t, endpoint.Next()) - - conn1 := &fakeConn{addr: "1"} - endpoint.AddConn(conn1) - assert.Equal(t, "1", endpoint.Next().Addr()) - - conn2 := &fakeConn{addr: "2"} - conn3 := &fakeConn{addr: "3"} - conn4 := &fakeConn{addr: "4"} - endpoint.AddConn(conn2) - endpoint.AddConn(conn3) - endpoint.AddConn(conn4) 
- - assert.Equal(t, "1", endpoint.Next().Addr()) - assert.Equal(t, "2", endpoint.Next().Addr()) - assert.Equal(t, "3", endpoint.Next().Addr()) - assert.Equal(t, "4", endpoint.Next().Addr()) - assert.Equal(t, "1", endpoint.Next().Addr()) - assert.Equal(t, "2", endpoint.Next().Addr()) - assert.Equal(t, "3", endpoint.Next().Addr()) - - assert.False(t, endpoint.RemoveConn(conn2)) - assert.False(t, endpoint.RemoveConn(conn3)) - assert.Equal(t, "1", endpoint.Next().Addr()) - assert.Equal(t, "4", endpoint.Next().Addr()) - assert.Equal(t, "1", endpoint.Next().Addr()) - - assert.False(t, endpoint.RemoveConn(conn1)) - assert.True(t, endpoint.RemoveConn(conn4)) - - assert.Nil(t, endpoint.Next()) -} diff --git a/server/proxy/options.go b/server/proxy/options.go deleted file mode 100644 index d68fe8d..0000000 --- a/server/proxy/options.go +++ /dev/null @@ -1,46 +0,0 @@ -package proxy - -import ( - "github.com/andydunstall/piko/pkg/forwarder" - "github.com/andydunstall/piko/pkg/log" -) - -type options struct { - forwarder forwarder.Forwarder - logger log.Logger -} - -type Option interface { - apply(*options) -} - -func defaultOptions() options { - return options{ - forwarder: forwarder.NewForwarder(), - logger: log.NewNopLogger(), - } -} - -type forwarderOption struct { - Forwarder forwarder.Forwarder -} - -func (o forwarderOption) apply(opts *options) { - opts.forwarder = o.Forwarder -} - -func WithForwarder(f forwarder.Forwarder) Option { - return forwarderOption{Forwarder: f} -} - -type loggerOption struct { - Logger log.Logger -} - -func (o loggerOption) apply(opts *options) { - opts.logger = o.Logger -} - -func WithLogger(l log.Logger) Option { - return loggerOption{Logger: l} -} diff --git a/server/proxy/proxy.go b/server/proxy/proxy.go deleted file mode 100644 index 620304b..0000000 --- a/server/proxy/proxy.go +++ /dev/null @@ -1,237 +0,0 @@ -package proxy - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "io" - "net/http" - "strings" - "time" - - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/server/cluster" - "go.uber.org/atomic" - "go.uber.org/zap" -) - -var ( - errEndpointNotFound = errors.New("not endpoint found") -) - -type Usage struct { - Requests *atomic.Uint64 - Upstreams *atomic.Uint64 -} - -// Proxy is responsible for forwarding requests to upstream endpoints. -type Proxy struct { - local *localProxy - remote *remoteProxy - - usage *Usage - - metrics *Metrics - - logger log.Logger -} - -func NewProxy(clusterState *cluster.State, opts ...Option) *Proxy { - options := defaultOptions() - for _, opt := range opts { - opt.apply(&options) - } - - metrics := NewMetrics() - logger := options.logger.WithSubsystem("proxy") - return &Proxy{ - local: newLocalProxy(metrics, logger), - remote: newRemoteProxy(clusterState, options.forwarder, metrics, logger), - usage: &Usage{ - Requests: atomic.NewUint64(0), - Upstreams: atomic.NewUint64(0), - }, - metrics: metrics, - logger: logger, - } -} - -// Request forwards the given HTTP request to an upstream endpoint and returns -// the response. -// -// If the request fails returns a response with status: -// - Missing endpoint ID: 401 (Bad request) -// - Upstream unreachable: 503 (Service unavailable) -// - Timeout: 504 (Gateway timeout) -func (p *Proxy) Request( - ctx context.Context, - r *http.Request, -) *http.Response { - // Whether the request was forwarded from another Piko node. 
- forwarded := r.Header.Get("x-piko-forward") == "true" - - logger := p.logger.With( - zap.String("host", r.Host), - zap.String("method", r.Method), - zap.String("path", r.URL.Path), - zap.Bool("forwarded", forwarded), - ) - p.usage.Requests.Inc() - - endpointID := endpointIDFromRequest(r) - if endpointID == "" { - logger.Warn("request: missing endpoint id") - return errorResponse(http.StatusBadRequest, "missing piko endpoint id") - } - - logger = logger.With(zap.String("endpoint-id", endpointID)) - - start := time.Now() - - // Attempt to send to an endpoint connected to the local node. - resp, err := p.local.Request(ctx, endpointID, r) - if err == nil { - logger.Debug( - "request: forwarded to local conn", - zap.Duration("latency", time.Since(start)), - ) - return resp - } - if !errors.Is(err, errEndpointNotFound) { - if errors.Is(err, context.DeadlineExceeded) { - logger.Warn("request: endpoint timeout", zap.Error(err)) - - return errorResponse( - http.StatusGatewayTimeout, - "endpoint timeout", - ) - } - - logger.Warn("request: endpoint unreachable", zap.Error(err)) - return errorResponse( - http.StatusServiceUnavailable, - "endpoint unreachable", - ) - } - - // If the request is from another Piko node though we don't have a - // connection for the endpoint, we don't forward again but return an - // error. - if forwarded { - logger.Warn("request: endpoint not found") - return errorResponse(http.StatusServiceUnavailable, "endpoint not found") - } - - // Set the 'x-piko-forward' before forwarding to a remote node. - r.Header.Set("x-piko-forward", "true") - - // Attempt to send the request to a Piko node with a connection for the - // endpoint. - resp, err = p.remote.Request(ctx, endpointID, r) - if err == nil { - logger.Debug( - "request: forwarded to remote", - zap.Duration("latency", time.Since(start)), - ) - - return resp - } - if !errors.Is(err, errEndpointNotFound) { - if errors.Is(err, context.DeadlineExceeded) { - logger.Warn("request: endpoint timeout", zap.Error(err)) - - return errorResponse( - http.StatusGatewayTimeout, - "endpoint timeout", - ) - } - - logger.Warn("request: endpoint unreachable", zap.Error(err)) - return errorResponse( - http.StatusServiceUnavailable, - "endpoint unreachable", - ) - } - - logger.Warn("request: endpoint not found") - return errorResponse(http.StatusServiceUnavailable, "endpoint not found") -} - -// AddConn registers a connection for an endpoint. -func (p *Proxy) AddConn(conn Conn) { - p.logger.Info( - "add conn", - zap.String("endpoint-id", conn.EndpointID()), - zap.String("addr", conn.Addr()), - ) - - p.usage.Upstreams.Inc() - - p.local.AddConn(conn) - p.remote.AddConn(conn) -} - -// RemoveConn removes a connection for an endpoint. -func (p *Proxy) RemoveConn(conn Conn) { - p.logger.Info( - "remove conn", - zap.String("endpoint-id", conn.EndpointID()), - zap.String("addr", conn.Addr()), - ) - p.local.RemoveConn(conn) - p.remote.RemoveConn(conn) -} - -// ConnAddrs returns a mapping of endpoint ID to connection address for -// all local connected endpoints. -func (p *Proxy) ConnAddrs() map[string][]string { - return p.local.ConnAddrs() -} - -func (p *Proxy) Usage() *Usage { - return p.usage -} - -func (p *Proxy) Metrics() *Metrics { - return p.metrics -} - -// endpointIDFromRequest returns the endpoint ID from the HTTP request, or an -// empty string if no endpoint ID is specified. -// -// This will check both the 'x-piko-endpoint' header and 'Host' header, where -// x-piko-endpoint takes precedence. 
-func endpointIDFromRequest(r *http.Request) string { - endpointID := r.Header.Get("x-piko-endpoint") - if endpointID != "" { - return endpointID - } - - host := r.Host - if host != "" && strings.Contains(host, ".") { - // If a host is given and contains a separator, use the bottom-level - // domain as the endpoint ID. - // - // Such as if the domain is 'xyz.piko.example.com', then 'xyz' is the - // endpoint ID. - return strings.Split(host, ".")[0] - } - - return "" -} - -type errorMessage struct { - Error string `json:"error"` -} - -func errorResponse(statusCode int, message string) *http.Response { - m := &errorMessage{ - Error: message, - } - b, _ := json.Marshal(m) - return &http.Response{ - StatusCode: statusCode, - Body: io.NopCloser(bytes.NewReader(b)), - } -} diff --git a/server/proxy/proxy_test.go b/server/proxy/proxy_test.go deleted file mode 100644 index cfccb15..0000000 --- a/server/proxy/proxy_test.go +++ /dev/null @@ -1,361 +0,0 @@ -package proxy - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "net/url" - "testing" - - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/server/cluster" - "github.com/stretchr/testify/assert" -) - -type fakeConn struct { - endpointID string - addr string - handler func(r *http.Request) (*http.Response, error) -} - -func (c *fakeConn) EndpointID() string { - return c.endpointID -} - -func (c *fakeConn) Addr() string { - return c.addr -} - -func (c *fakeConn) Request( - _ context.Context, - r *http.Request, -) (*http.Response, error) { - return c.handler(r) -} - -type fakeForwarder struct { - handler func(addr string, r *http.Request) (*http.Response, error) -} - -func (f *fakeForwarder) Request( - _ context.Context, - addr string, - r *http.Request, -) (*http.Response, error) { - return f.handler(addr, r) -} - -func TestProxy(t *testing.T) { - t.Run("forward request remote ok", func(t *testing.T) { - networkMap := cluster.NewState(&cluster.Node{}, log.NewNopLogger()) - networkMap.AddNode(&cluster.Node{ - ID: "node-1", - Status: cluster.NodeStatusActive, - ProxyAddr: "1.2.3.4:1234", - Endpoints: map[string]int{ - "my-endpoint": 5, - }, - }) - - handler := func(addr string, r *http.Request) (*http.Response, error) { - assert.Equal(t, "1.2.3.4:1234", addr) - assert.Equal(t, "true", r.Header.Get("x-piko-forward")) - return &http.Response{ - StatusCode: http.StatusOK, - }, nil - } - forwarder := &fakeForwarder{ - handler: handler, - } - proxy := NewProxy(networkMap, WithForwarder(forwarder)) - - header := make(http.Header) - resp := proxy.Request(context.TODO(), &http.Request{ - URL: &url.URL{ - Path: "/foo", - }, - Host: "my-endpoint.piko.com", - Header: header, - }) - - assert.Equal(t, http.StatusOK, resp.StatusCode) - }) - - t.Run("forward request remote endpoint timeout", func(t *testing.T) { - networkMap := cluster.NewState(&cluster.Node{}, log.NewNopLogger()) - networkMap.AddNode(&cluster.Node{ - ID: "node-1", - Status: cluster.NodeStatusActive, - ProxyAddr: "1.2.3.4:1234", - Endpoints: map[string]int{ - "my-endpoint": 5, - }, - }) - - handler := func(addr string, r *http.Request) (*http.Response, error) { - return nil, fmt.Errorf("error: %w", context.DeadlineExceeded) - } - forwarder := &fakeForwarder{ - handler: handler, - } - proxy := NewProxy(networkMap, WithForwarder(forwarder)) - - header := make(http.Header) - resp := proxy.Request(context.TODO(), &http.Request{ - URL: &url.URL{ - Path: "/foo", - }, - Host: "my-endpoint.piko.com", - Header: header, - }) - - assert.Equal(t, 
http.StatusGatewayTimeout, resp.StatusCode) - - var m errorMessage - assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) - assert.Equal(t, "endpoint timeout", m.Error) - }) - - t.Run("forward request remote endpoint unreachable", func(t *testing.T) { - networkMap := cluster.NewState(&cluster.Node{}, log.NewNopLogger()) - networkMap.AddNode(&cluster.Node{ - ID: "node-1", - Status: cluster.NodeStatusActive, - ProxyAddr: "1.2.3.4:1234", - Endpoints: map[string]int{ - "my-endpoint": 5, - }, - }) - - handler := func(addr string, r *http.Request) (*http.Response, error) { - return nil, fmt.Errorf("unknown error") - } - forwarder := &fakeForwarder{ - handler: handler, - } - proxy := NewProxy(networkMap, WithForwarder(forwarder)) - - header := make(http.Header) - resp := proxy.Request(context.TODO(), &http.Request{ - URL: &url.URL{ - Path: "/foo", - }, - Host: "my-endpoint.piko.com", - Header: header, - }) - assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - - var m errorMessage - assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) - assert.Equal(t, "endpoint unreachable", m.Error) - }) - - t.Run("forward request remote endpoint not found", func(t *testing.T) { - proxy := NewProxy( - cluster.NewState(&cluster.Node{}, log.NewNopLogger()), - ) - - header := make(http.Header) - resp := proxy.Request(context.TODO(), &http.Request{ - URL: &url.URL{ - Path: "/foo", - }, - Host: "my-endpoint.piko.com", - Header: header, - }) - - assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - - var m errorMessage - assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) - assert.Equal(t, "endpoint not found", m.Error) - }) - - t.Run("forward local ok", func(t *testing.T) { - handler := func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusOK, - }, nil - } - conn := &fakeConn{ - endpointID: "my-endpoint", - addr: "1.1.1.1", - handler: handler, - } - - proxy := NewProxy( - cluster.NewState(&cluster.Node{}, log.NewNopLogger()), - ) - - proxy.AddConn(conn) - - resp := proxy.Request(context.TODO(), &http.Request{ - URL: &url.URL{ - Path: "/foo", - }, - Host: "my-endpoint.piko.com", - }) - assert.Equal(t, http.StatusOK, resp.StatusCode) - }) - - t.Run("forward request local endpoint timeout", func(t *testing.T) { - handler := func(r *http.Request) (*http.Response, error) { - return nil, fmt.Errorf("error: %w", context.DeadlineExceeded) - } - conn := &fakeConn{ - endpointID: "my-endpoint", - addr: "1.1.1.1", - handler: handler, - } - - proxy := NewProxy( - cluster.NewState(&cluster.Node{}, log.NewNopLogger()), - ) - - proxy.AddConn(conn) - - resp := proxy.Request(context.TODO(), &http.Request{ - URL: &url.URL{ - Path: "/foo", - }, - Host: "my-endpoint.piko.com", - }) - assert.Equal(t, http.StatusGatewayTimeout, resp.StatusCode) - - var m errorMessage - assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) - assert.Equal(t, "endpoint timeout", m.Error) - }) - - t.Run("forward request local endpoint unreachable", func(t *testing.T) { - handler := func(r *http.Request) (*http.Response, error) { - return nil, fmt.Errorf("unknown error") - } - conn := &fakeConn{ - endpointID: "my-endpoint", - addr: "1.1.1.1", - handler: handler, - } - - proxy := NewProxy( - cluster.NewState(&cluster.Node{}, log.NewNopLogger()), - ) - - proxy.AddConn(conn) - - resp := proxy.Request(context.TODO(), &http.Request{ - URL: &url.URL{ - Path: "/foo", - }, - Host: "my-endpoint.piko.com", - }) - assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - - var m 
errorMessage - assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) - assert.Equal(t, "endpoint unreachable", m.Error) - }) - - t.Run("forward request local endpoint not found", func(t *testing.T) { - proxy := NewProxy( - cluster.NewState(&cluster.Node{}, log.NewNopLogger()), - ) - - header := make(http.Header) - // Set forward header to avoid being forwarded to a remote node. - header.Set("x-piko-forward", "true") - req := &http.Request{ - URL: &url.URL{ - Path: "/foo", - }, - Host: "my-endpoint.piko.com", - Header: header, - } - - resp := proxy.Request(context.TODO(), req) - assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - var m errorMessage - assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) - assert.Equal(t, "endpoint not found", m.Error) - - conn := &fakeConn{ - endpointID: "my-endpoint", - } - proxy.AddConn(conn) - proxy.RemoveConn(conn) - - resp = proxy.Request(context.TODO(), req) - assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) - assert.Equal(t, "endpoint not found", m.Error) - }) - - t.Run("add conn", func(t *testing.T) { - networkMap := cluster.NewState( - &cluster.Node{ - ID: "local", - }, log.NewNopLogger(), - ) - proxy := NewProxy(networkMap) - - conn := &fakeConn{ - endpointID: "my-endpoint", - } - proxy.AddConn(conn) - // Verify the cluster was updated. - assert.Equal(t, map[string]int{ - "my-endpoint": 1, - }, networkMap.LocalNode().Endpoints) - - proxy.RemoveConn(conn) - assert.Equal(t, 0, len(networkMap.LocalNode().Endpoints)) - }) - - t.Run("missing endpoint", func(t *testing.T) { - proxy := NewProxy( - cluster.NewState(&cluster.Node{}, log.NewNopLogger()), - ) - - resp := proxy.Request(context.TODO(), &http.Request{ - URL: &url.URL{ - Path: "/foo", - }, - Host: "localhost:9000", - }) - - assert.Equal(t, http.StatusBadRequest, resp.StatusCode) - - var m errorMessage - assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) - assert.Equal(t, "missing piko endpoint id", m.Error) - }) -} - -func TestEndpointIDFromRequest(t *testing.T) { - t.Run("host header", func(t *testing.T) { - endpointID := endpointIDFromRequest(&http.Request{ - Host: "my-endpoint.piko.com:9000", - }) - assert.Equal(t, "my-endpoint", endpointID) - }) - - t.Run("x-piko-endpoint header", func(t *testing.T) { - header := make(http.Header) - header.Add("x-piko-endpoint", "my-endpoint") - endpointID := endpointIDFromRequest(&http.Request{ - // Even though the host header is provided, 'x-piko-endpoint' - // takes precedence. - Host: "another-endpoint.piko.com:9000", - Header: header, - }) - assert.Equal(t, "my-endpoint", endpointID) - }) - - t.Run("no endpoint", func(t *testing.T) { - endpointID := endpointIDFromRequest(&http.Request{ - Host: "localhost:9000", - }) - assert.Equal(t, "", endpointID) - }) -} diff --git a/server/proxy/remote.go b/server/proxy/remote.go deleted file mode 100644 index 5570be5..0000000 --- a/server/proxy/remote.go +++ /dev/null @@ -1,74 +0,0 @@ -package proxy - -import ( - "context" - "net/http" - - "github.com/andydunstall/piko/pkg/forwarder" - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/server/cluster" - "github.com/prometheus/client_golang/prometheus" -) - -// remoteProxy is responsible for forwarding requests to Piko server nodes with -// an upstream connection for the target endpoint. 
-type remoteProxy struct { - clusterState *cluster.State - - forwarder forwarder.Forwarder - - metrics *Metrics - - logger log.Logger -} - -func newRemoteProxy( - clusterState *cluster.State, - forwarder forwarder.Forwarder, - metrics *Metrics, - logger log.Logger, -) *remoteProxy { - return &remoteProxy{ - clusterState: clusterState, - forwarder: forwarder, - metrics: metrics, - logger: logger, - } -} - -func (p *remoteProxy) Request( - ctx context.Context, - endpointID string, - r *http.Request, -) (*http.Response, error) { - nodeID, addr, ok := p.findNode(endpointID) - if !ok { - return nil, errEndpointNotFound - } - p.metrics.ForwardedRemoteTotal.With(prometheus.Labels{ - "node_id": nodeID, - }).Inc() - return p.forwarder.Request(ctx, addr, r) -} - -func (p *remoteProxy) AddConn(conn Conn) { - // Update the cluster to notify other nodes that we have a connection for - // the endpoint. - p.clusterState.AddLocalEndpoint(conn.EndpointID()) -} - -func (p *remoteProxy) RemoveConn(conn Conn) { - p.clusterState.RemoveLocalEndpoint(conn.EndpointID()) -} - -// findNode looks up a node with an upstream connection for the given endpoint -// and returns the node ID and proxy address. -func (p *remoteProxy) findNode(endpointID string) (string, string, bool) { - // TODO(andydunstall): This doesn't yet do any load balancing. It just - // selects the first node. - node, ok := p.clusterState.LookupEndpoint(endpointID) - if !ok { - return "", "", false - } - return node.ID, node.ProxyAddr, true -} diff --git a/server/proxy/reverseproxy.go b/server/proxy/reverseproxy.go new file mode 100644 index 0000000..25590bd --- /dev/null +++ b/server/proxy/reverseproxy.go @@ -0,0 +1,161 @@ +package proxy + +import ( + "context" + "encoding/json" + "errors" + "net" + "net/http" + "net/http/httputil" + "strings" + "time" + + "github.com/andydunstall/piko/pkg/log" + "github.com/andydunstall/piko/server/upstream" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type contextKey int + +const ( + endpointContextKey contextKey = iota + upstreamContextKey +) + +type ReverseProxy struct { + upstreams upstream.Manager + + proxy *httputil.ReverseProxy + + timeout time.Duration + + logger log.Logger +} + +func NewReverseProxy( + upstreams upstream.Manager, + timeout time.Duration, + logger log.Logger, +) *ReverseProxy { + rp := &ReverseProxy{ + upstreams: upstreams, + timeout: timeout, + logger: logger, + } + + rp.proxy = &httputil.ReverseProxy{ + Director: func(req *http.Request) { + req.URL.Scheme = "http" + req.URL.Host = req.Context().Value(endpointContextKey).(string) + + req.Header.Set("x-piko-forward", "true") + }, + Transport: &http.Transport{ + DialContext: rp.dialUpstream, + // 'connections' to the upstream are multiplexed over a single TCP + // connection so theres no overhead to creating new connections, + // therefore it doesn't make sense to keep them alive. 
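+ // (With keep-alives disabled the Transport will not pool idle
+ // connections, so every proxied request dials the selected upstream
+ // again via dialUpstream below.)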
+ DisableKeepAlives: true, + }, + ErrorLog: logger.StdLogger(zapcore.WarnLevel), + ErrorHandler: rp.errorHandler, + } + + return rp +} + +func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if p.timeout != 0 { + ctx, cancel := context.WithTimeout(r.Context(), p.timeout) + defer cancel() + + r = r.WithContext(ctx) + } + + endpointID := EndpointIDFromRequest(r) + if endpointID == "" { + p.logger.Warn("request missing endpoint id") + + _ = errorResponse(w, http.StatusBadRequest, "missing endpoint id") + return + } + + ctx := context.WithValue(r.Context(), endpointContextKey, endpointID) + r = r.WithContext(ctx) + + // Whether the request was forwarded from another Piko node. + forwarded := r.Header.Get("x-piko-forward") == "true" + + // If there is a connected upstream, attempt to forward the request to one + // of those upstreams. Note this includes remote nodes that are reporting + // they have an available upstream. We don't allow multiple hops, so if + // forwarded is true we only select from local nodes. + upstream, ok := p.upstreams.Select(endpointID, !forwarded) + if !ok { + _ = errorResponse(w, http.StatusBadGateway, "no available upstreams") + return + } + + // Add the upstream to the context to pass to 'DialContext'. + ctx = context.WithValue(r.Context(), upstreamContextKey, upstream) + r = r.WithContext(ctx) + + p.proxy.ServeHTTP(w, r) +} + +func (p *ReverseProxy) dialUpstream(ctx context.Context, _, _ string) (net.Conn, error) { + // As a bit of a hack to work with http.Transport, we add the upstream + // to the dial context. + upstream := ctx.Value(upstreamContextKey).(upstream.Upstream) + return upstream.Dial() +} + +func (p *ReverseProxy) errorHandler(w http.ResponseWriter, _ *http.Request, err error) { + p.logger.Warn("proxy request", zap.Error(err)) + + if errors.Is(err, context.DeadlineExceeded) { + _ = errorResponse(w, http.StatusGatewayTimeout, "upstream timeout") + return + } + _ = errorResponse(w, http.StatusBadGateway, "upstream unreachable") +} + +type errorMessage struct { + Error string `json:"error"` +} + +func errorResponse(w http.ResponseWriter, statusCode int, message string) error { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(statusCode) + + m := &errorMessage{ + Error: message, + } + return json.NewEncoder(w).Encode(m) +} + +// EndpointIDFromRequest returns the endpoint ID from the HTTP request, or an +// empty string if no endpoint ID is specified. +// +// This will check both the 'x-piko-endpoint' header and 'Host' header, where +// x-piko-endpoint takes precedence. +func EndpointIDFromRequest(r *http.Request) string { + endpointID := r.Header.Get("x-piko-endpoint") + if endpointID != "" { + return endpointID + } + + host := r.Host + if host != "" && strings.Contains(host, ".") { + // If a host is given and contains a separator, use the bottom-level + // domain as the endpoint ID. + // + // Such as if the domain is 'xyz.piko.example.com', then 'xyz' is the + // endpoint ID. 
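+ // Note only the first label is used, so a host of
+ // 'a.b.piko.example.com' gives endpoint ID 'a', while a host without
+ // a '.' (such as 'localhost:9000') falls through and no endpoint ID
+ // is found.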
+ return strings.Split(host, ".")[0] + } + + return "" +} diff --git a/server/proxy/reverseproxy_test.go b/server/proxy/reverseproxy_test.go new file mode 100644 index 0000000..74f4f97 --- /dev/null +++ b/server/proxy/reverseproxy_test.go @@ -0,0 +1,283 @@ +package proxy + +import ( + "bytes" + "encoding/json" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/andydunstall/piko/pkg/log" + "github.com/andydunstall/piko/server/upstream" + "github.com/stretchr/testify/assert" +) + +// import ( +// "bytes" +// "encoding/json" +// "io" +// "net/http" +// "net/http/httptest" +// "strings" +// "testing" +// "time" +// +// "github.com/andydunstall/piko/agent/config" +// "github.com/andydunstall/piko/pkg/log" +// "github.com/stretchr/testify/assert" +// ) + +type fakeManager struct { + handler func(endpointID string, allowForward bool) (upstream.Upstream, bool) +} + +func (m *fakeManager) Select( + endpointID string, + allowForward bool, +) (upstream.Upstream, bool) { + return m.handler(endpointID, allowForward) +} + +func (m *fakeManager) AddConn(_ upstream.Upstream) { +} + +func (m *fakeManager) RemoveConn(_ upstream.Upstream) { +} + +type tcpUpstream struct { + addr string +} + +func (u *tcpUpstream) Dial() (net.Conn, error) { + return net.Dial("tcp", u.addr) +} + +func (u *tcpUpstream) EndpointID() string { + return "my-endpoint" +} + +func TestReverseProxy_Forward(t *testing.T) { + t.Run("ok", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/foo/bar", r.URL.Path) + assert.Equal(t, "a=b", r.URL.RawQuery) + + buf := new(strings.Builder) + // nolint + io.Copy(buf, r.Body) + assert.Equal(t, "foo", buf.String()) + + // nolint + w.Write([]byte("bar")) + }, + )) + defer server.Close() + + proxy := NewReverseProxy( + &fakeManager{ + handler: func(endpointID string, allowForward bool) (upstream.Upstream, bool) { + return &tcpUpstream{ + addr: server.Listener.Addr().String(), + }, true + }, + }, + time.Second, + log.NewNopLogger(), + ) + + b := bytes.NewReader([]byte("foo")) + r := httptest.NewRequest(http.MethodGet, "/foo/bar?a=b", b) + r.Header.Add("x-piko-endpoint", "my-endpoint") + + w := httptest.NewRecorder() + proxy.ServeHTTP(w, r) + + resp := w.Result() + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + buf := new(strings.Builder) + // nolint + io.Copy(buf, resp.Body) + assert.Equal(t, "bar", buf.String()) + }) + + t.Run("timeout", func(t *testing.T) { + blockCh := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + <-blockCh + }, + )) + defer server.Close() + defer close(blockCh) + + proxy := NewReverseProxy( + &fakeManager{ + handler: func(endpointID string, allowForward bool) (upstream.Upstream, bool) { + return &tcpUpstream{ + addr: server.Listener.Addr().String(), + }, true + }, + }, + time.Millisecond, + log.NewNopLogger(), + ) + + r := httptest.NewRequest(http.MethodGet, "/", nil) + + w := httptest.NewRecorder() + proxy.ServeHTTP(w, r) + + resp := w.Result() + defer resp.Body.Close() + + assert.Equal(t, http.StatusGatewayTimeout, resp.StatusCode) + + m := errorMessage{} + assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) + assert.Equal(t, "upstream timeout", m.Error) + }) + + t.Run("upstream unreachable", func(t *testing.T) { + proxy := NewReverseProxy( + &fakeManager{ + handler: func(endpointID string, allowForward bool) 
(upstream.Upstream, bool) { + return &tcpUpstream{ + addr: "localhost:55555", + }, true + }, + }, + time.Second, + log.NewNopLogger(), + ) + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Add("x-piko-endpoint", "my-endpoint") + + w := httptest.NewRecorder() + proxy.ServeHTTP(w, r) + + resp := w.Result() + defer resp.Body.Close() + + assert.Equal(t, http.StatusBadGateway, resp.StatusCode) + + m := errorMessage{} + assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) + assert.Equal(t, "upstream unreachable", m.Error) + }) + + t.Run("no available upstreams", func(t *testing.T) { + proxy := NewReverseProxy( + &fakeManager{ + handler: func(endpointID string, allowForward bool) (upstream.Upstream, bool) { + assert.Equal(t, "my-endpoint", endpointID) + assert.True(t, allowForward) + return nil, false + }, + }, + time.Second, + log.NewNopLogger(), + ) + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Add("x-piko-endpoint", "my-endpoint") + + w := httptest.NewRecorder() + proxy.ServeHTTP(w, r) + + resp := w.Result() + defer resp.Body.Close() + + assert.Equal(t, http.StatusBadGateway, resp.StatusCode) + + m := errorMessage{} + assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) + assert.Equal(t, "no available upstreams", m.Error) + }) + + t.Run("no available upstreams forwarded", func(t *testing.T) { + proxy := NewReverseProxy( + &fakeManager{ + handler: func(endpointID string, allowForward bool) (upstream.Upstream, bool) { + assert.Equal(t, "my-endpoint", endpointID) + assert.False(t, allowForward) + return nil, false + }, + }, + time.Second, + log.NewNopLogger(), + ) + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Add("x-piko-endpoint", "my-endpoint") + r.Header.Add("x-piko-forward", "true") + + w := httptest.NewRecorder() + proxy.ServeHTTP(w, r) + + resp := w.Result() + defer resp.Body.Close() + + assert.Equal(t, http.StatusBadGateway, resp.StatusCode) + + m := errorMessage{} + assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) + assert.Equal(t, "no available upstreams", m.Error) + }) + + t.Run("missing endpoint id", func(t *testing.T) { + proxy := NewReverseProxy(nil, time.Second, log.NewNopLogger()) + + r := httptest.NewRequest(http.MethodGet, "/", nil) + // The host must have a '.' separator to be parsed as an endpoint ID. + r.Host = "foo" + + w := httptest.NewRecorder() + proxy.ServeHTTP(w, r) + + resp := w.Result() + defer resp.Body.Close() + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + + m := errorMessage{} + assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) + assert.Equal(t, "missing endpoint id", m.Error) + }) +} + +func TestEndpointIDFromRequest(t *testing.T) { + t.Run("host header", func(t *testing.T) { + endpointID := EndpointIDFromRequest(&http.Request{ + Host: "my-endpoint.piko.com:9000", + }) + assert.Equal(t, "my-endpoint", endpointID) + }) + + t.Run("x-piko-endpoint header", func(t *testing.T) { + header := make(http.Header) + header.Add("x-piko-endpoint", "my-endpoint") + endpointID := EndpointIDFromRequest(&http.Request{ + // Even though the host header is provided, 'x-piko-endpoint' + // takes precedence. 
+ Host: "another-endpoint.piko.com:9000", + Header: header, + }) + assert.Equal(t, "my-endpoint", endpointID) + }) + + t.Run("no endpoint", func(t *testing.T) { + endpointID := EndpointIDFromRequest(&http.Request{ + Host: "localhost:9000", + }) + assert.Equal(t, "", endpointID) + }) +} diff --git a/server/proxy/server.go b/server/proxy/server.go new file mode 100644 index 0000000..865f9f7 --- /dev/null +++ b/server/proxy/server.go @@ -0,0 +1,109 @@ +package proxy + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + + "github.com/andydunstall/piko/pkg/log" + "github.com/andydunstall/piko/pkg/middleware" + "github.com/andydunstall/piko/server/config" + "github.com/andydunstall/piko/server/upstream" + "github.com/gin-gonic/gin" + "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type Server struct { + proxy *ReverseProxy + + httpServer *http.Server + + logger log.Logger +} + +func NewServer( + upstreams upstream.Manager, + proxyConfig config.ProxyConfig, + registry *prometheus.Registry, + tlsConfig *tls.Config, + logger log.Logger, +) *Server { + logger = logger.WithSubsystem("proxy") + + router := gin.New() + s := &Server{ + proxy: NewReverseProxy(upstreams, proxyConfig.Timeout, logger), + httpServer: &http.Server{ + Handler: router, + TLSConfig: tlsConfig, + ReadTimeout: proxyConfig.HTTP.ReadTimeout, + ReadHeaderTimeout: proxyConfig.HTTP.ReadHeaderTimeout, + WriteTimeout: proxyConfig.HTTP.WriteTimeout, + IdleTimeout: proxyConfig.HTTP.IdleTimeout, + MaxHeaderBytes: proxyConfig.HTTP.MaxHeaderBytes, + ErrorLog: logger.StdLogger(zapcore.WarnLevel), + }, + logger: logger, + } + + // Recover from panics. + router.Use(gin.CustomRecoveryWithWriter(nil, s.panicRoute)) + + router.Use(middleware.NewLogger(proxyConfig.AccessLog, logger)) + + metrics := middleware.NewMetrics("proxy") + if registry != nil { + metrics.Register(registry) + } + router.Use(metrics.Handler()) + + router.NoRoute(s.proxyRoute) + + return s +} + +func (s *Server) Serve(ln net.Listener) error { + s.logger.Info( + "starting proxy server", + zap.String("addr", ln.Addr().String()), + ) + + var err error + if s.httpServer.TLSConfig != nil { + err = s.httpServer.ServeTLS(ln, "", "") + } else { + err = s.httpServer.Serve(ln) + } + + if err != nil && err != http.ErrServerClosed { + return fmt.Errorf("http serve: %w", err) + } + return nil +} + +func (s *Server) Shutdown(ctx context.Context) error { + return s.httpServer.Shutdown(ctx) +} + +// proxyRoute handles proxied requests from proxy clients. +func (s *Server) proxyRoute(c *gin.Context) { + s.proxy.ServeHTTP(c.Writer, c.Request) +} + +func (s *Server) panicRoute(c *gin.Context, err any) { + s.logger.Error( + "handler panic", + zap.String("path", c.FullPath()), + zap.Any("err", err), + ) + c.AbortWithStatus(http.StatusInternalServerError) +} + +func init() { + // Disable Gin debug logs. 
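+ // (In debug mode Gin would otherwise print output for each registered
+ // route; the server relies on its own logging middleware instead.)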
+ gin.SetMode(gin.ReleaseMode) +} diff --git a/server/proxy/status.go b/server/proxy/status.go deleted file mode 100644 index 443b9f5..0000000 --- a/server/proxy/status.go +++ /dev/null @@ -1,29 +0,0 @@ -package proxy - -import ( - "net/http" - - "github.com/andydunstall/piko/server/status" - "github.com/gin-gonic/gin" -) - -type Status struct { - proxy *Proxy -} - -func NewStatus(proxy *Proxy) *Status { - return &Status{ - proxy: proxy, - } -} - -func (s *Status) Register(group *gin.RouterGroup) { - group.GET("/endpoints", s.listEndpointsRoute) -} - -func (s *Status) listEndpointsRoute(c *gin.Context) { - endpoints := s.proxy.ConnAddrs() - c.JSON(http.StatusOK, endpoints) -} - -var _ status.Handler = &Status{} diff --git a/server/server.go b/server/server.go deleted file mode 100644 index 22c3e8f..0000000 --- a/server/server.go +++ /dev/null @@ -1,406 +0,0 @@ -package server - -import ( - "context" - "crypto/tls" - "fmt" - "net" - "os" - "strings" - - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/server/auth" - "github.com/andydunstall/piko/server/cluster" - "github.com/andydunstall/piko/server/config" - "github.com/andydunstall/piko/server/gossip" - "github.com/andydunstall/piko/server/proxy" - adminserver "github.com/andydunstall/piko/server/server/admin" - proxyserver "github.com/andydunstall/piko/server/server/proxy" - upstreamserver "github.com/andydunstall/piko/server/server/upstream" - "github.com/andydunstall/piko/server/usage" - "github.com/golang-jwt/jwt/v5" - "github.com/hashicorp/go-sockaddr" - rungroup "github.com/oklog/run" - "github.com/prometheus/client_golang/prometheus" - "go.uber.org/zap" -) - -// Server manages setting up and running a Piko server node. -type Server struct { - proxyLn net.Listener - proxyTLSConfig *tls.Config - - upstreamLn net.Listener - upstreamTLSConfig *tls.Config - - adminLn net.Listener - adminTLSConfig *tls.Config - - gossipStreamLn net.Listener - gossipPacketLn net.PacketConn - - conf *config.Config - - logger log.Logger -} - -func NewServer(conf *config.Config, logger log.Logger) (*Server, error) { - adminLn, err := net.Listen("tcp", conf.Admin.BindAddr) - if err != nil { - return nil, fmt.Errorf("admin listen: %s: %w", conf.Admin.BindAddr, err) - } - adminTLSConfig, err := conf.Admin.TLS.Load() - if err != nil { - return nil, fmt.Errorf("admin tls: %w", err) - } - - proxyLn, err := net.Listen("tcp", conf.Proxy.BindAddr) - if err != nil { - return nil, fmt.Errorf("proxy listen: %s: %w", conf.Proxy.BindAddr, err) - } - proxyTLSConfig, err := conf.Proxy.TLS.Load() - if err != nil { - return nil, fmt.Errorf("proxy tls: %w", err) - } - - upstreamLn, err := net.Listen("tcp", conf.Upstream.BindAddr) - if err != nil { - return nil, fmt.Errorf("upstream listen: %s: %w", conf.Upstream.BindAddr, err) - } - upstreamTLSConfig, err := conf.Upstream.TLS.Load() - if err != nil { - return nil, fmt.Errorf("upstream tls: %w", err) - } - - gossipStreamLn, err := net.Listen("tcp", conf.Gossip.BindAddr) - if err != nil { - return nil, fmt.Errorf("gossip listen: %s: %w", conf.Gossip.BindAddr, err) - } - - gossipPacketLn, err := net.ListenUDP("udp", &net.UDPAddr{ - IP: gossipStreamLn.Addr().(*net.TCPAddr).IP, - Port: gossipStreamLn.Addr().(*net.TCPAddr).Port, - }) - if err != nil { - return nil, fmt.Errorf("gossip listen: %s: %w", conf.Gossip.BindAddr, err) - } - - if conf.Cluster.NodeID == "" { - nodeID := cluster.GenerateNodeID() - if conf.Cluster.NodeIDPrefix != "" { - nodeID = 
conf.Cluster.NodeIDPrefix + nodeID - } - conf.Cluster.NodeID = nodeID - } - - // Incase the address has port 0, set the bind address to the listen - // address. - conf.Proxy.BindAddr = proxyLn.Addr().String() - conf.Upstream.BindAddr = upstreamLn.Addr().String() - conf.Admin.BindAddr = adminLn.Addr().String() - conf.Gossip.BindAddr = gossipStreamLn.Addr().String() - - if conf.Proxy.AdvertiseAddr == "" { - advertiseAddr, err := advertiseAddrFromBindAddr(conf.Proxy.BindAddr) - if err != nil { - logger.Error("invalid configuration", zap.Error(err)) - os.Exit(1) - } - conf.Proxy.AdvertiseAddr = advertiseAddr - } - if conf.Upstream.AdvertiseAddr == "" { - advertiseAddr, err := advertiseAddrFromBindAddr(conf.Upstream.BindAddr) - if err != nil { - logger.Error("invalid configuration", zap.Error(err)) - os.Exit(1) - } - conf.Upstream.AdvertiseAddr = advertiseAddr - } - if conf.Admin.AdvertiseAddr == "" { - advertiseAddr, err := advertiseAddrFromBindAddr(conf.Admin.BindAddr) - if err != nil { - logger.Error("invalid configuration", zap.Error(err)) - os.Exit(1) - } - conf.Admin.AdvertiseAddr = advertiseAddr - } - if conf.Gossip.AdvertiseAddr == "" { - advertiseAddr, err := advertiseAddrFromBindAddr(conf.Gossip.BindAddr) - if err != nil { - logger.Error("invalid configuration", zap.Error(err)) - os.Exit(1) - } - conf.Gossip.AdvertiseAddr = advertiseAddr - } - - return &Server{ - proxyLn: proxyLn, - proxyTLSConfig: proxyTLSConfig, - upstreamLn: upstreamLn, - upstreamTLSConfig: upstreamTLSConfig, - adminLn: adminLn, - adminTLSConfig: adminTLSConfig, - gossipStreamLn: gossipStreamLn, - gossipPacketLn: gossipPacketLn, - conf: conf, - logger: logger, - }, nil -} - -func (s *Server) Run(ctx context.Context) error { - var verifier auth.Verifier - if s.conf.Auth.AuthEnabled() { - verifierConf := auth.JWTVerifierConfig{ - HMACSecretKey: []byte(s.conf.Auth.TokenHMACSecretKey), - Audience: s.conf.Auth.TokenAudience, - Issuer: s.conf.Auth.TokenIssuer, - } - - if s.conf.Auth.TokenRSAPublicKey != "" { - rsaPublicKey, err := jwt.ParseRSAPublicKeyFromPEM( - []byte(s.conf.Auth.TokenRSAPublicKey), - ) - if err != nil { - return fmt.Errorf("parse rsa public key: %w", err) - } - verifierConf.RSAPublicKey = rsaPublicKey - } - if s.conf.Auth.TokenECDSAPublicKey != "" { - ecdsaPublicKey, err := jwt.ParseECPublicKeyFromPEM( - []byte(s.conf.Auth.TokenECDSAPublicKey), - ) - if err != nil { - return fmt.Errorf("parse ecdsa public key: %w", err) - } - verifierConf.ECDSAPublicKey = ecdsaPublicKey - } - verifier = auth.NewJWTVerifier(verifierConf) - } - - s.logger.Info("starting piko server") - - registry := prometheus.NewRegistry() - - clusterState := cluster.NewState(&cluster.Node{ - ID: s.conf.Cluster.NodeID, - ProxyAddr: s.conf.Proxy.AdvertiseAddr, - AdminAddr: s.conf.Admin.AdvertiseAddr, - }, s.logger) - clusterState.Metrics().Register(registry) - - gossiper := gossip.NewGossip( - clusterState, - s.gossipStreamLn, - s.gossipPacketLn, - &s.conf.Gossip, - s.logger, - ) - defer gossiper.Close() - gossiper.Metrics().Register(registry) - - // Attempt to join an existing cluster. - // - // Note when running on Kubernetes, if this is the first member, as it is - // not yet ready the service DNS record won't resolve so this may fail. - // Therefore we attempt to join though continue booting if join fails. - // Once booted we then attempt to join again with retries. 
- nodeIDs, err := gossiper.JoinOnBoot(s.conf.Cluster.Join) - if err != nil { - s.logger.Warn("failed to join cluster", zap.Error(err)) - } - if len(nodeIDs) > 0 { - s.logger.Info( - "joined cluster", - zap.Strings("node-ids", nodeIDs), - ) - } - - p := proxy.NewProxy(clusterState, proxy.WithLogger(s.logger)) - p.Metrics().Register(registry) - - adminServer := adminserver.NewServer( - s.adminLn, - clusterState, - s.adminTLSConfig, - registry, - s.logger, - ) - adminServer.AddStatus("/cluster", cluster.NewStatus(clusterState)) - adminServer.AddStatus("/gossip", gossip.NewStatus(gossiper)) - adminServer.AddStatus("/proxy", proxy.NewStatus(p)) - - proxyServer := proxyserver.NewServer( - s.proxyLn, - p, - &s.conf.Proxy, - s.proxyTLSConfig, - registry, - s.logger, - ) - - upstreamServer := upstreamserver.NewServer( - s.upstreamLn, - p, - verifier, - s.upstreamTLSConfig, - s.logger, - ) - - reporter := usage.NewReporter(p, s.logger) - - var group rungroup.Group - - // Termination handler. - shutdownCtx, shutdownCancel := context.WithCancel(ctx) - group.Add(func() error { - select { - case <-ctx.Done(): - case <-shutdownCtx.Done(): - } - return nil - }, func(error) { - shutdownCancel() - }) - - // Proxy server. - group.Add(func() error { - if err := proxyServer.Serve(); err != nil { - return fmt.Errorf("proxy server serve: %w", err) - } - return nil - }, func(error) { - shutdownCtx, cancel := context.WithTimeout( - context.Background(), - s.conf.GracePeriod, - ) - defer cancel() - - if err := proxyServer.Shutdown(shutdownCtx); err != nil { - s.logger.Warn("failed to gracefully shutdown proxy server", zap.Error(err)) - } - - s.logger.Info("proxy server shut down") - }) - - // Upstream server. - group.Add(func() error { - if err := upstreamServer.Serve(); err != nil { - return fmt.Errorf("upstream server serve: %w", err) - } - return nil - }, func(error) { - shutdownCtx, cancel := context.WithTimeout( - context.Background(), - s.conf.GracePeriod, - ) - defer cancel() - - if err := upstreamServer.Shutdown(shutdownCtx); err != nil { - s.logger.Warn("failed to gracefully shutdown upstream server", zap.Error(err)) - } - - s.logger.Info("upstream server shut down") - }) - - // Admin server. - group.Add(func() error { - if err := adminServer.Serve(); err != nil { - return fmt.Errorf("admin server serve: %w", err) - } - return nil - }, func(error) { - shutdownCtx, cancel := context.WithTimeout( - context.Background(), - s.conf.GracePeriod, - ) - defer cancel() - - if err := adminServer.Shutdown(shutdownCtx); err != nil { - s.logger.Warn("failed to gracefully shutdown server", zap.Error(err)) - } - - s.logger.Info("admin server shut down") - }) - - // Gossip. - gossipCtx, gossipCancel := context.WithCancel(ctx) - group.Add(func() error { - if len(nodeIDs) == 0 { - nodeIDs, err = gossiper.JoinOnStartup(gossipCtx, s.conf.Cluster.Join) - if err != nil { - if s.conf.Cluster.AbortIfJoinFails { - return fmt.Errorf("join on startup: %w", err) - } - s.logger.Warn("failed to join cluster", zap.Error(err)) - } - if len(nodeIDs) > 0 { - s.logger.Info( - "joined cluster", - zap.Strings("node-ids", nodeIDs), - ) - } - } - - <-gossipCtx.Done() - - leaveCtx, cancel := context.WithTimeout( - context.Background(), - s.conf.GracePeriod, - ) - defer cancel() - - // Leave as soon as we receive the shutdown signal to avoid receiving - // forward proxy requests. 
- if err := gossiper.Leave(leaveCtx); err != nil { - s.logger.Warn("failed to gracefully leave cluster", zap.Error(err)) - } else { - s.logger.Info("left cluster") - } - - return nil - }, func(error) { - gossipCancel() - }) - - if !s.conf.Usage.Disable { - // Usage. - usageCtx, usageCancel := context.WithCancel(ctx) - group.Add(func() error { - reporter.Run(usageCtx) - return nil - }, func(error) { - usageCancel() - }) - } - - if err := group.Run(); err != nil { - return err - } - - s.logger.Info("shutdown complete") - - return nil -} - -func advertiseAddrFromBindAddr(bindAddr string) (string, error) { - if strings.HasPrefix(bindAddr, ":") { - bindAddr = "0.0.0.0" + bindAddr - } - - host, port, err := net.SplitHostPort(bindAddr) - if err != nil { - return "", fmt.Errorf("invalid bind addr: %s: %w", bindAddr, err) - } - - if host == "0.0.0.0" || host == "::" { - ip, err := sockaddr.GetPrivateIP() - if err != nil { - return "", fmt.Errorf("get interface addr: %w", err) - } - if ip == "" { - return "", fmt.Errorf("no private ip found") - } - return ip + ":" + port, nil - } - return bindAddr, nil -} diff --git a/server/server/admin/server.go b/server/server/admin/server.go deleted file mode 100644 index 1c557f0..0000000 --- a/server/server/admin/server.go +++ /dev/null @@ -1,215 +0,0 @@ -package admin - -import ( - "context" - "crypto/tls" - "fmt" - "io" - "net" - "net/http" - "net/http/pprof" - "time" - - "github.com/andydunstall/piko/pkg/forwarder" - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/server/cluster" - "github.com/andydunstall/piko/server/server/middleware" - "github.com/andydunstall/piko/server/status" - "github.com/gin-gonic/gin" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promhttp" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -// Server is the admin HTTP server, which exposes endpoints for metrics, health -// and inspecting the node status. -type Server struct { - ln net.Listener - - router *gin.Engine - - httpServer *http.Server - - clusterState *cluster.State - - forwarder forwarder.Forwarder - - registry *prometheus.Registry - - logger log.Logger -} - -func NewServer( - ln net.Listener, - clusterState *cluster.State, - tlsConfig *tls.Config, - registry *prometheus.Registry, - logger log.Logger, -) *Server { - logger = logger.WithSubsystem("admin.server") - - router := gin.New() - server := &Server{ - ln: ln, - router: router, - httpServer: &http.Server{ - Addr: ln.Addr().String(), - Handler: router, - TLSConfig: tlsConfig, - ErrorLog: logger.StdLogger(zapcore.WarnLevel), - }, - clusterState: clusterState, - forwarder: forwarder.NewForwarder(), - registry: registry, - logger: logger, - } - - // Recover from panics. 
- server.router.Use(gin.CustomRecoveryWithWriter(nil, server.panicRoute)) - - server.router.Use(middleware.NewLogger(logger)) - - metrics := middleware.NewMetrics("admin") - if registry != nil { - metrics.Register(registry) - } - router.Use(metrics.Handler()) - - if clusterState != nil { - router.Use(server.forwardInterceptor) - } - - server.registerRoutes() - - return server -} - -func (s *Server) AddStatus(route string, handler status.Handler) { - group := s.router.Group("/status").Group(route) - handler.Register(group) -} - -func (s *Server) Serve() error { - s.logger.Info("starting http server", zap.String("addr", s.ln.Addr().String())) - - var err error - if s.httpServer.TLSConfig != nil { - err = s.httpServer.ServeTLS(s.ln, "", "") - } else { - err = s.httpServer.Serve(s.ln) - } - - if err != nil && err != http.ErrServerClosed { - return fmt.Errorf("http serve: %w", err) - } - return nil -} - -// Shutdown attempts to gracefully shutdown the server by waiting for pending -// requests to complete. -func (s *Server) Shutdown(ctx context.Context) error { - return s.httpServer.Shutdown(ctx) -} - -func (s *Server) Close() error { - return s.httpServer.Close() -} - -func (s *Server) registerRoutes() { - s.router.GET("/health", s.healthRoute) - s.router.GET("/ready", s.readyRoute) - - if s.registry != nil { - s.router.GET("/metrics", s.metricsHandler()) - } - - // From https://github.com/gin-contrib/pprof/blob/934af36b21728278339704005bcef2eec1375091/pprof.go#L32. - pprofGroup := s.router.Group("/debug/pprof") - pprofGroup.GET("/", gin.WrapF(pprof.Index)) - pprofGroup.GET("/cmdline", gin.WrapF(pprof.Cmdline)) - pprofGroup.GET("/profile", gin.WrapF(pprof.Profile)) - pprofGroup.POST("/symbol", gin.WrapF(pprof.Symbol)) - pprofGroup.GET("/symbol", gin.WrapF(pprof.Symbol)) - pprofGroup.GET("/trace", gin.WrapF(pprof.Trace)) - pprofGroup.GET("/allocs", gin.WrapH(pprof.Handler("allocs"))) - pprofGroup.GET("/block", gin.WrapH(pprof.Handler("block"))) - pprofGroup.GET("/goroutine", gin.WrapH(pprof.Handler("goroutine"))) - pprofGroup.GET("/heap", gin.WrapH(pprof.Handler("heap"))) - pprofGroup.GET("/mutex", gin.WrapH(pprof.Handler("mutex"))) - pprofGroup.GET("/threadcreate", gin.WrapH(pprof.Handler("threadcreate"))) -} - -func (s *Server) healthRoute(c *gin.Context) { - c.Status(http.StatusOK) -} - -func (s *Server) readyRoute(c *gin.Context) { - c.Status(http.StatusOK) -} - -func (s *Server) panicRoute(c *gin.Context, err any) { - s.logger.Error( - "handler panic", - zap.String("path", c.FullPath()), - zap.Any("err", err), - ) - c.AbortWithStatus(http.StatusInternalServerError) -} - -func (s *Server) metricsHandler() gin.HandlerFunc { - h := promhttp.HandlerFor( - s.registry, - promhttp.HandlerOpts{Registry: s.registry}, - ) - return func(c *gin.Context) { - h.ServeHTTP(c.Writer, c.Request) - } -} - -// forwardInterceptor intercepts all admin requests. If the request has a -// 'forward' query, the request is forwarded to the node with the requested ID. -func (s *Server) forwardInterceptor(c *gin.Context) { - forward, ok := c.GetQuery("forward") - if !ok || forward == s.clusterState.LocalID() { - // No forward configuration so handle locally. 
- c.Next() - return - } - - node, ok := s.clusterState.Node(forward) - if !ok { - c.AbortWithStatus(http.StatusNotFound) - return - } - - ctx, cancel := context.WithTimeout(c, time.Second*15) - defer cancel() - - resp, err := s.forwarder.Request(ctx, node.AdminAddr, c.Request) - if err != nil { - s.logger.Warn( - "forward admin request", - zap.String("forward-node-id", node.ID), - zap.String("forward-addr", node.AdminAddr), - zap.Error(err), - ) - c.AbortWithStatus(http.StatusInternalServerError) - return - } - - // Write the response status, headers and body. - for k, v := range resp.Header { - c.Writer.Header()[k] = v - } - c.Writer.WriteHeader(resp.StatusCode) - if _, err := io.Copy(c.Writer, resp.Body); err != nil { - s.logger.Warn("failed to write response", zap.Error(err)) - } - c.Abort() -} - -func init() { - // Disable Gin debug logs. - gin.SetMode(gin.ReleaseMode) -} diff --git a/server/server/admin/server_integration_test.go b/server/server/admin/server_integration_test.go deleted file mode 100644 index f684a21..0000000 --- a/server/server/admin/server_integration_test.go +++ /dev/null @@ -1,257 +0,0 @@ -//go:build integration - -package admin - -import ( - "bytes" - "context" - "crypto/tls" - "fmt" - "net" - "net/http" - "testing" - - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/pkg/testutil" - "github.com/andydunstall/piko/server/cluster" - "github.com/andydunstall/piko/server/status" - "github.com/gin-gonic/gin" - "github.com/prometheus/client_golang/prometheus" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -type fakeStatus struct { -} - -func (s *fakeStatus) Register(group *gin.RouterGroup) { - group.GET("/foo", s.fooRoute) -} - -func (s *fakeStatus) fooRoute(c *gin.Context) { - c.String(http.StatusOK, "foo") -} - -var _ status.Handler = &fakeStatus{} - -func TestServer_AdminRoutes(t *testing.T) { - adminLn, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - adminServer := NewServer( - adminLn, - nil, - nil, - prometheus.NewRegistry(), - log.NewNopLogger(), - ) - go func() { - require.NoError(t, adminServer.Serve()) - }() - defer adminServer.Shutdown(context.TODO()) - - t.Run("health", func(t *testing.T) { - url := fmt.Sprintf("http://%s/health", adminLn.Addr().String()) - resp, err := http.Get(url) - assert.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, http.StatusOK, resp.StatusCode) - }) - - t.Run("metrics", func(t *testing.T) { - url := fmt.Sprintf("http://%s/metrics", adminLn.Addr().String()) - resp, err := http.Get(url) - assert.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, http.StatusOK, resp.StatusCode) - }) - - t.Run("not found", func(t *testing.T) { - url := fmt.Sprintf("http://%s/foo", adminLn.Addr().String()) - resp, err := http.Get(url) - assert.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, http.StatusNotFound, resp.StatusCode) - }) -} - -func TestServer_StatusRoutes(t *testing.T) { - adminLn, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - adminServer := NewServer( - adminLn, - nil, - nil, - prometheus.NewRegistry(), - log.NewNopLogger(), - ) - adminServer.AddStatus("/mystatus", &fakeStatus{}) - - go func() { - require.NoError(t, adminServer.Serve()) - }() - defer adminServer.Shutdown(context.TODO()) - - t.Run("status ok", func(t *testing.T) { - url := fmt.Sprintf("http://%s/status/mystatus/foo", adminLn.Addr().String()) - resp, err := http.Get(url) - assert.NoError(t, err) - defer 
resp.Body.Close() - - assert.Equal(t, http.StatusOK, resp.StatusCode) - - buf := new(bytes.Buffer) - //nolint - buf.ReadFrom(resp.Body) - assert.Equal(t, []byte("foo"), buf.Bytes()) - }) - - t.Run("not found", func(t *testing.T) { - url := fmt.Sprintf("http://%s/status/notfound", adminLn.Addr().String()) - resp, err := http.Get(url) - assert.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, http.StatusNotFound, resp.StatusCode) - }) -} - -// TestServer_ForwardRequest tests forwarding an admin request to another node -// in the cluster. -func TestServer_ForwardRequest(t *testing.T) { - admin1Ln, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - cluster1State := cluster.NewState(&cluster.Node{ - ID: "node-1", - AdminAddr: admin1Ln.Addr().String(), - }, log.NewNopLogger()) - - admin1Server := NewServer( - admin1Ln, - cluster1State, - nil, - prometheus.NewRegistry(), - log.NewNopLogger(), - ) - // Note only node 1 registers the status route. - admin1Server.AddStatus("/mystatus", &fakeStatus{}) - - go func() { - require.NoError(t, admin1Server.Serve()) - }() - defer admin1Server.Shutdown(context.TODO()) - - admin2Ln, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - cluster2State := cluster.NewState(&cluster.Node{ - ID: "node-2", - AdminAddr: admin2Ln.Addr().String(), - }, log.NewNopLogger()) - cluster2State.AddNode(&cluster.Node{ - ID: "node-1", - AdminAddr: admin1Ln.Addr().String(), - }) - - admin2Server := NewServer( - admin2Ln, - cluster2State, - nil, - prometheus.NewRegistry(), - log.NewNopLogger(), - ) - - go func() { - require.NoError(t, admin2Server.Serve()) - }() - defer admin2Server.Shutdown(context.TODO()) - - t.Run("forward ok", func(t *testing.T) { - url := fmt.Sprintf("http://%s/status/mystatus/foo?forward=node-1", admin2Ln.Addr().String()) - resp, err := http.Get(url) - assert.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, http.StatusOK, resp.StatusCode) - - buf := new(bytes.Buffer) - //nolint - buf.ReadFrom(resp.Body) - assert.Equal(t, []byte("foo"), buf.Bytes()) - }) - - t.Run("forward not found", func(t *testing.T) { - url := fmt.Sprintf("http://%s/status/mystatus/foo?forward=node-3", admin2Ln.Addr().String()) - resp, err := http.Get(url) - assert.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, http.StatusNotFound, resp.StatusCode) - }) -} - -func TestServer_TLS(t *testing.T) { - rootCAPool, cert, err := testutil.LocalTLSServerCert() - require.NoError(t, err) - - adminLn, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - tlsConfig := &tls.Config{} - tlsConfig.Certificates = []tls.Certificate{cert} - - adminServer := NewServer( - adminLn, - nil, - tlsConfig, - prometheus.NewRegistry(), - log.NewNopLogger(), - ) - go func() { - require.NoError(t, adminServer.Serve()) - }() - defer adminServer.Shutdown(context.TODO()) - - t.Run("https ok", func(t *testing.T) { - tlsConfig = &tls.Config{ - RootCAs: rootCAPool, - } - transport := &http.Transport{ - TLSClientConfig: tlsConfig, - } - client := &http.Client{ - Transport: transport, - } - - req, _ := http.NewRequest( - http.MethodGet, - fmt.Sprintf("https://%s/health", adminLn.Addr().String()), - nil, - ) - resp, err := client.Do(req) - assert.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, http.StatusOK, resp.StatusCode) - }) - - t.Run("https bad ca", func(t *testing.T) { - url := fmt.Sprintf("https://%s/health", adminLn.Addr().String()) - _, err := http.Get(url) - assert.ErrorContains(t, err, "certificate signed by 
unknown authority") - }) - - t.Run("http", func(t *testing.T) { - url := fmt.Sprintf("http://%s/health", adminLn.Addr().String()) - resp, err := http.Get(url) - assert.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, http.StatusBadRequest, resp.StatusCode) - }) -} diff --git a/server/server/middleware/logger.go b/server/server/middleware/logger.go deleted file mode 100644 index 211e118..0000000 --- a/server/server/middleware/logger.go +++ /dev/null @@ -1,34 +0,0 @@ -package middleware - -import ( - "time" - - "github.com/andydunstall/piko/pkg/log" - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// NewLogger creates logging middleware that logs every request. -func NewLogger(logger log.Logger) gin.HandlerFunc { - logger = logger.WithSubsystem(logger.Subsystem() + ".route") - return func(c *gin.Context) { - start := time.Now() - path := c.Request.URL.Path - if c.Request.URL.RawQuery != "" { - path = path + "?" + c.Request.URL.RawQuery - } - - // Process request - c.Next() - - logger.Debug( - "http request", - zap.String("method", c.Request.Method), - zap.Int("status", c.Writer.Status()), - zap.String("path", path), - zap.Int64("latency", time.Since(start).Milliseconds()), - zap.String("client-ip", c.ClientIP()), - zap.Int("resp-size", c.Writer.Size()), - ) - } -} diff --git a/server/server/proxy/server.go b/server/server/proxy/server.go deleted file mode 100644 index 83f2b37..0000000 --- a/server/server/proxy/server.go +++ /dev/null @@ -1,152 +0,0 @@ -package server - -import ( - "context" - "crypto/tls" - "fmt" - "io" - "net" - "net/http" - - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/server/config" - "github.com/andydunstall/piko/server/server/middleware" - "github.com/gin-gonic/gin" - "github.com/prometheus/client_golang/prometheus" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -type Proxy interface { - Request(ctx context.Context, r *http.Request) *http.Response -} - -// Server is the HTTP server for the proxy, which proxies all incoming -// requests. -type Server struct { - ln net.Listener - - router *gin.Engine - - httpServer *http.Server - - proxy Proxy - - shutdownCtx context.Context - shutdownCancel func() - - conf *config.ProxyConfig - - logger log.Logger -} - -func NewServer( - ln net.Listener, - proxy Proxy, - conf *config.ProxyConfig, - tlsConfig *tls.Config, - registry *prometheus.Registry, - logger log.Logger, -) *Server { - logger = logger.WithSubsystem("proxy.server") - - shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) - - router := gin.New() - server := &Server{ - ln: ln, - router: router, - httpServer: &http.Server{ - Addr: ln.Addr().String(), - Handler: router, - TLSConfig: tlsConfig, - ReadTimeout: conf.HTTP.ReadTimeout, - ReadHeaderTimeout: conf.HTTP.ReadHeaderTimeout, - WriteTimeout: conf.HTTP.WriteTimeout, - IdleTimeout: conf.HTTP.IdleTimeout, - MaxHeaderBytes: conf.HTTP.MaxHeaderBytes, - ErrorLog: logger.StdLogger(zapcore.WarnLevel), - }, - shutdownCtx: shutdownCtx, - shutdownCancel: shutdownCancel, - proxy: proxy, - conf: conf, - logger: logger, - } - - // Recover from panics. 
- server.router.Use(gin.CustomRecoveryWithWriter(nil, server.panicRoute)) - - server.router.Use(middleware.NewLogger(logger)) - - metrics := middleware.NewMetrics("proxy") - if registry != nil { - metrics.Register(registry) - } - router.Use(metrics.Handler()) - - server.registerRoutes() - - return server -} - -func (s *Server) Serve() error { - s.logger.Info("starting http server", zap.String("addr", s.ln.Addr().String())) - var err error - if s.httpServer.TLSConfig != nil { - err = s.httpServer.ServeTLS(s.ln, "", "") - } else { - err = s.httpServer.Serve(s.ln) - } - - if err != nil && err != http.ErrServerClosed { - return fmt.Errorf("http serve: %w", err) - } - return nil -} - -func (s *Server) Shutdown(ctx context.Context) error { - return s.httpServer.Shutdown(ctx) -} - -func (s *Server) registerRoutes() { - // Handle not found routes, which includes all proxied endpoints. - s.router.NoRoute(s.notFoundRoute) -} - -// proxyRoute handles proxied requests from proxy clients. -func (s *Server) proxyRoute(c *gin.Context) { - ctx, cancel := context.WithTimeout( - context.Background(), - s.conf.GatewayTimeout, - ) - defer cancel() - - resp := s.proxy.Request(ctx, c.Request) - // Write the response status, headers and body. - for k, v := range resp.Header { - c.Writer.Header()[k] = v - } - c.Writer.WriteHeader(resp.StatusCode) - if _, err := io.Copy(c.Writer, resp.Body); err != nil { - s.logger.Warn("failed to write response", zap.Error(err)) - } -} - -func (s *Server) notFoundRoute(c *gin.Context) { - s.proxyRoute(c) -} - -func (s *Server) panicRoute(c *gin.Context, err any) { - s.logger.Error( - "handler panic", - zap.String("path", c.FullPath()), - zap.Any("err", err), - ) - c.AbortWithStatus(http.StatusInternalServerError) -} - -func init() { - // Disable Gin debug logs. 
- gin.SetMode(gin.ReleaseMode) -} diff --git a/server/server/proxy/server_integration_test.go b/server/server/proxy/server_integration_test.go deleted file mode 100644 index 6b7cfd5..0000000 --- a/server/server/proxy/server_integration_test.go +++ /dev/null @@ -1,176 +0,0 @@ -//go:build integration - -package server - -import ( - "bytes" - "context" - "crypto/tls" - "fmt" - "io" - "net" - "net/http" - "testing" - - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/pkg/testutil" - "github.com/andydunstall/piko/server/config" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -type fakeProxy struct { - handler func(ctx context.Context, r *http.Request) *http.Response -} - -func (p *fakeProxy) Request(ctx context.Context, r *http.Request) *http.Response { - return p.handler(ctx, r) -} - -func TestServer_ProxyRequest(t *testing.T) { - t.Run("forwarded", func(t *testing.T) { - handler := func(ctx context.Context, r *http.Request) *http.Response { - assert.Equal(t, "/foo/bar", r.URL.Path) - - header := make(http.Header) - header.Add("h1", "v1") - header.Add("h2", "v2") - header.Add("h3", "v3") - body := bytes.NewReader([]byte("foo")) - return &http.Response{ - StatusCode: http.StatusOK, - Header: header, - Body: io.NopCloser(body), - } - } - - proxyLn, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - proxyServer := NewServer( - proxyLn, - &fakeProxy{handler: handler}, - &config.ProxyConfig{}, - nil, - nil, - log.NewNopLogger(), - ) - go func() { - require.NoError(t, proxyServer.Serve()) - }() - defer proxyServer.Shutdown(context.TODO()) - - url := fmt.Sprintf("http://%s/foo/bar", proxyLn.Addr().String()) - resp, err := http.Get(url) - assert.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Equal(t, "v1", resp.Header.Get("h1")) - assert.Equal(t, "v2", resp.Header.Get("h2")) - assert.Equal(t, "v3", resp.Header.Get("h3")) - - buf := new(bytes.Buffer) - //nolint - buf.ReadFrom(resp.Body) - assert.Equal(t, []byte("foo"), buf.Bytes()) - }) -} - -func TestServer_HandlePanic(t *testing.T) { - handler := func(ctx context.Context, r *http.Request) *http.Response { - panic("fail") - } - - proxyLn, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - proxyServer := NewServer( - proxyLn, - &fakeProxy{handler: handler}, - &config.ProxyConfig{}, - nil, - nil, - log.NewNopLogger(), - ) - go func() { - require.NoError(t, proxyServer.Serve()) - }() - defer proxyServer.Shutdown(context.TODO()) - - url := fmt.Sprintf("http://%s/foo/bar", proxyLn.Addr().String()) - resp, err := http.Get(url) - assert.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) -} - -func TestServer_TLS(t *testing.T) { - rootCAPool, cert, err := testutil.LocalTLSServerCert() - require.NoError(t, err) - - proxyLn, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - tlsConfig := &tls.Config{} - tlsConfig.Certificates = []tls.Certificate{cert} - - handler := func(ctx context.Context, r *http.Request) *http.Response { - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte("foo"))), - } - } - - proxyServer := NewServer( - proxyLn, - &fakeProxy{handler: handler}, - &config.ProxyConfig{}, - tlsConfig, - nil, - log.NewNopLogger(), - ) - go func() { - require.NoError(t, proxyServer.Serve()) - }() - defer proxyServer.Shutdown(context.TODO()) - - t.Run("https ok", func(t 
*testing.T) { - tlsConfig = &tls.Config{ - RootCAs: rootCAPool, - } - transport := &http.Transport{ - TLSClientConfig: tlsConfig, - } - client := &http.Client{ - Transport: transport, - } - - req, _ := http.NewRequest( - http.MethodGet, - fmt.Sprintf("https://%s", proxyLn.Addr().String()), - nil, - ) - resp, err := client.Do(req) - assert.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, http.StatusOK, resp.StatusCode) - }) - - t.Run("https bad ca", func(t *testing.T) { - url := fmt.Sprintf("https://%s", proxyLn.Addr().String()) - _, err := http.Get(url) - assert.ErrorContains(t, err, "certificate signed by unknown authority") - }) - - t.Run("http", func(t *testing.T) { - url := fmt.Sprintf("http://%s", proxyLn.Addr().String()) - resp, err := http.Get(url) - assert.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, http.StatusBadRequest, resp.StatusCode) - }) -} diff --git a/server/server/upstream/rpcserver.go b/server/server/upstream/rpcserver.go deleted file mode 100644 index a0b4c8c..0000000 --- a/server/server/upstream/rpcserver.go +++ /dev/null @@ -1,23 +0,0 @@ -package server - -import "github.com/andydunstall/piko/pkg/rpc" - -type rpcServer struct { - rpcHandler *rpc.Handler -} - -func newRPCServer() *rpcServer { - server := &rpcServer{ - rpcHandler: rpc.NewHandler(), - } - server.rpcHandler.Register(rpc.TypeHeartbeat, server.Heartbeat) - return server -} - -func (s *rpcServer) Handler() *rpc.Handler { - return s.rpcHandler -} - -func (s *rpcServer) Heartbeat(m []byte) []byte { - return m -} diff --git a/server/server/upstream/server_integration_test.go b/server/server/upstream/server_integration_test.go deleted file mode 100644 index f2293c7..0000000 --- a/server/server/upstream/server_integration_test.go +++ /dev/null @@ -1,307 +0,0 @@ -//go:build integration - -package server - -import ( - "context" - "crypto/tls" - "fmt" - "net" - "testing" - "time" - - "github.com/andydunstall/piko/pkg/conn/websocket" - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/pkg/rpc" - "github.com/andydunstall/piko/pkg/testutil" - "github.com/andydunstall/piko/server/auth" - proxy "github.com/andydunstall/piko/server/proxy" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -type fakeProxy struct { - addUpstreamCh chan string - removeUpstreamCh chan string -} - -func newFakeProxy() *fakeProxy { - return &fakeProxy{ - addUpstreamCh: make(chan string), - removeUpstreamCh: make(chan string), - } -} - -func (p *fakeProxy) AddConn(conn proxy.Conn) { - p.addUpstreamCh <- conn.EndpointID() -} - -func (p *fakeProxy) RemoveConn(conn proxy.Conn) { - p.removeUpstreamCh <- conn.EndpointID() -} - -type fakeVerifier struct { - handler func(token string) (auth.EndpointToken, error) -} - -func (v *fakeVerifier) VerifyEndpointToken(token string) (auth.EndpointToken, error) { - return v.handler(token) -} - -var _ auth.Verifier = &fakeVerifier{} - -func TestServer_AddConn(t *testing.T) { - t.Run("ok", func(t *testing.T) { - upstreamLn, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - proxy := newFakeProxy() - upstreamServer := NewServer( - upstreamLn, - proxy, - nil, - nil, - log.NewNopLogger(), - ) - go func() { - require.NoError(t, upstreamServer.Serve()) - }() - defer upstreamServer.Shutdown(context.TODO()) - - url := fmt.Sprintf( - "ws://%s/piko/v1/listener/my-endpoint", - upstreamLn.Addr().String(), - ) - rpcServer := newRPCServer() - conn, err := websocket.Dial(context.TODO(), url) - 
require.NoError(t, err) - - // Add client stream and ensure upstream added to proxy. - stream := rpc.NewStream(conn, rpcServer.Handler(), log.NewNopLogger()) - assert.Equal(t, "my-endpoint", <-proxy.addUpstreamCh) - - // Close client stream and ensure upstream removed from proxy. - stream.Close() - assert.Equal(t, "my-endpoint", <-proxy.removeUpstreamCh) - }) - - t.Run("authenticated", func(t *testing.T) { - upstreamLn, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - verifier := &fakeVerifier{ - handler: func(token string) (auth.EndpointToken, error) { - assert.Equal(t, "123", token) - return auth.EndpointToken{ - Expiry: time.Now().Add(time.Hour), - Endpoints: []string{"my-endpoint"}, - }, nil - }, - } - - proxy := newFakeProxy() - upstreamServer := NewServer( - upstreamLn, - proxy, - verifier, - nil, - log.NewNopLogger(), - ) - go func() { - require.NoError(t, upstreamServer.Serve()) - }() - defer upstreamServer.Shutdown(context.TODO()) - - url := fmt.Sprintf( - "ws://%s/piko/v1/listener/my-endpoint", - upstreamLn.Addr().String(), - ) - rpcServer := newRPCServer() - conn, err := websocket.Dial(context.TODO(), url, websocket.WithToken("123")) - require.NoError(t, err) - - // Add client stream and ensure upstream added to proxy. - stream := rpc.NewStream(conn, rpcServer.Handler(), log.NewNopLogger()) - assert.Equal(t, "my-endpoint", <-proxy.addUpstreamCh) - - // Close client stream and ensure upstream removed from proxy. - stream.Close() - assert.Equal(t, "my-endpoint", <-proxy.removeUpstreamCh) - }) - - t.Run("token expires", func(t *testing.T) { - upstreamLn, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - verifier := &fakeVerifier{ - handler: func(token string) (auth.EndpointToken, error) { - assert.Equal(t, "123", token) - return auth.EndpointToken{ - // Set a short expiry as we wait for the token to expire. - Expiry: time.Now().Add(time.Millisecond * 10), - Endpoints: []string{"my-endpoint"}, - }, nil - }, - } - - proxy := newFakeProxy() - upstreamServer := NewServer( - upstreamLn, - proxy, - verifier, - nil, - log.NewNopLogger(), - ) - go func() { - require.NoError(t, upstreamServer.Serve()) - }() - defer upstreamServer.Shutdown(context.TODO()) - - url := fmt.Sprintf( - "ws://%s/piko/v1/listener/my-endpoint", - upstreamLn.Addr().String(), - ) - rpcServer := newRPCServer() - conn, err := websocket.Dial(context.TODO(), url, websocket.WithToken("123")) - require.NoError(t, err) - - // Add client stream and ensure upstream added to proxy. - stream := rpc.NewStream(conn, rpcServer.Handler(), log.NewNopLogger()) - defer stream.Close() - assert.Equal(t, "my-endpoint", <-proxy.addUpstreamCh) - - // Wait for the token to expire and the server should close the - // connection and remove it from the proxy. 
- assert.Equal(t, "my-endpoint", <-proxy.removeUpstreamCh) - }) - - t.Run("endpoint not permitted", func(t *testing.T) { - upstreamLn, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - verifier := &fakeVerifier{ - handler: func(token string) (auth.EndpointToken, error) { - assert.Equal(t, "123", token) - return auth.EndpointToken{ - Expiry: time.Now().Add(time.Hour), - Endpoints: []string{"foo"}, - }, nil - }, - } - - upstreamServer := NewServer( - upstreamLn, - nil, - verifier, - nil, - log.NewNopLogger(), - ) - go func() { - require.NoError(t, upstreamServer.Serve()) - }() - defer upstreamServer.Shutdown(context.TODO()) - - url := fmt.Sprintf( - "ws://%s/piko/v1/listener/my-endpoint", - upstreamLn.Addr().String(), - ) - _, err = websocket.Dial(context.TODO(), url, websocket.WithToken("123")) - require.Error(t, err) - }) - - t.Run("unauthenticated", func(t *testing.T) { - upstreamLn, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - verifier := &fakeVerifier{ - handler: func(token string) (auth.EndpointToken, error) { - assert.Equal(t, "123", token) - return auth.EndpointToken{}, auth.ErrInvalidToken - }, - } - - upstreamServer := NewServer( - upstreamLn, - nil, - verifier, - nil, - log.NewNopLogger(), - ) - go func() { - require.NoError(t, upstreamServer.Serve()) - }() - defer upstreamServer.Shutdown(context.TODO()) - - url := fmt.Sprintf( - "ws://%s/piko/v1/listener/my-endpoint", - upstreamLn.Addr().String(), - ) - _, err = websocket.Dial(context.TODO(), url, websocket.WithToken("123")) - require.Error(t, err) - }) -} - -func TestServer_TLS(t *testing.T) { - rootCAPool, cert, err := testutil.LocalTLSServerCert() - require.NoError(t, err) - - upstreamLn, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - tlsConfig := &tls.Config{} - tlsConfig.Certificates = []tls.Certificate{cert} - - proxy := newFakeProxy() - upstreamServer := NewServer( - upstreamLn, - proxy, - nil, - tlsConfig, - log.NewNopLogger(), - ) - go func() { - require.NoError(t, upstreamServer.Serve()) - }() - defer upstreamServer.Shutdown(context.TODO()) - - t.Run("wss ok", func(t *testing.T) { - url := fmt.Sprintf( - "wss://%s/piko/v1/listener/my-endpoint", - upstreamLn.Addr().String(), - ) - clientTLSConfig := &tls.Config{ - RootCAs: rootCAPool, - } - conn, err := websocket.Dial( - context.TODO(), url, websocket.WithTLSConfig(clientTLSConfig), - ) - require.NoError(t, err) - - // Add client stream and ensure upstream added to proxy. - rpcServer := newRPCServer() - stream := rpc.NewStream(conn, rpcServer.Handler(), log.NewNopLogger()) - assert.Equal(t, "my-endpoint", <-proxy.addUpstreamCh) - - // Close client stream and ensure upstream removed from proxy. 
- stream.Close() - assert.Equal(t, "my-endpoint", <-proxy.removeUpstreamCh) - }) - - t.Run("wss bad ca", func(t *testing.T) { - url := fmt.Sprintf( - "wss://%s/piko/v1/listener/my-endpoint", - upstreamLn.Addr().String(), - ) - _, err := websocket.Dial(context.TODO(), url) - require.ErrorContains(t, err, "certificate signed by unknown authority") - }) - - t.Run("ws", func(t *testing.T) { - url := fmt.Sprintf( - "ws://%s/piko/v1/listener/my-endpoint", - upstreamLn.Addr().String(), - ) - _, err := websocket.Dial(context.TODO(), url) - require.ErrorContains(t, err, "bad handshake") - }) -} diff --git a/server/status/client/client.go b/server/status/client/client.go new file mode 100644 index 0000000..c41849c --- /dev/null +++ b/server/status/client/client.go @@ -0,0 +1,54 @@ +package client + +import ( + "fmt" + "io" + "net/http" + "net/url" + fspath "path" + "time" +) + +type Client struct { + httpClient *http.Client + + url *url.URL +} + +func NewClient(url *url.URL) *Client { + return &Client{ + httpClient: &http.Client{ + Timeout: time.Second * 15, + }, + url: url, + } +} + +func (c *Client) SetURL(url *url.URL) { + c.url = url +} + +func (c *Client) Request(path string) (io.ReadCloser, error) { + url := new(url.URL) + *url = *c.url + + url.Path = fspath.Join(url.Path, path) + + req, err := http.NewRequest(http.MethodGet, url.String(), nil) + if err != nil { + return nil, fmt.Errorf("request: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request: %w", err) + } + + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + + return nil, fmt.Errorf("request: bad status: %d", resp.StatusCode) + } + + return resp.Body, nil +} diff --git a/server/status/client/cluster.go b/server/status/client/cluster.go new file mode 100644 index 0000000..a3d3566 --- /dev/null +++ b/server/status/client/cluster.go @@ -0,0 +1,46 @@ +package client + +import ( + "encoding/json" + "fmt" + + "github.com/andydunstall/piko/server/cluster" +) + +type Cluster struct { + client *Client +} + +func NewCluster(client *Client) *Cluster { + return &Cluster{ + client: client, + } +} + +func (c *Cluster) Nodes() ([]*cluster.NodeMetadata, error) { + r, err := c.client.Request("/status/cluster/nodes") + if err != nil { + return nil, err + } + defer r.Close() + + var nodes []*cluster.NodeMetadata + if err := json.NewDecoder(r).Decode(&nodes); err != nil { + return nil, fmt.Errorf("decode response: %w", err) + } + return nodes, nil +} + +func (c *Cluster) Node(nodeID string) (*cluster.Node, error) { + r, err := c.client.Request("/status/cluster/nodes/" + nodeID) + if err != nil { + return nil, err + } + defer r.Close() + + var node cluster.Node + if err := json.NewDecoder(r).Decode(&node); err != nil { + return nil, fmt.Errorf("decode response: %w", err) + } + return &node, nil +} diff --git a/server/status/client/gossip.go b/server/status/client/gossip.go new file mode 100644 index 0000000..fec3e06 --- /dev/null +++ b/server/status/client/gossip.go @@ -0,0 +1,46 @@ +package client + +import ( + "encoding/json" + "fmt" + + "github.com/andydunstall/piko/pkg/gossip" +) + +type Gossip struct { + client *Client +} + +func NewGossip(client *Client) *Gossip { + return &Gossip{ + client: client, + } +} + +func (c *Gossip) Nodes() ([]gossip.NodeMetadata, error) { + r, err := c.client.Request("/status/gossip/nodes") + if err != nil { + return nil, err + } + defer r.Close() + + var nodes []gossip.NodeMetadata + if err := json.NewDecoder(r).Decode(&nodes); err != nil { + return 
nil, fmt.Errorf("decode response: %w", err) + } + return nodes, nil +} + +func (c *Gossip) Node(nodeID string) (*gossip.NodeState, error) { + r, err := c.client.Request("/status/gossip/nodes/" + nodeID) + if err != nil { + return nil, err + } + defer r.Close() + + var node gossip.NodeState + if err := json.NewDecoder(r).Decode(&node); err != nil { + return nil, fmt.Errorf("decode response: %w", err) + } + return &node, nil +} diff --git a/status/config/config.go b/server/status/config/config.go similarity index 75% rename from status/config/config.go rename to server/status/config/config.go index 5967026..3629cb1 100644 --- a/status/config/config.go +++ b/server/status/config/config.go @@ -24,8 +24,6 @@ func (c *ServerConfig) Validate() error { type Config struct { Server ServerConfig `json:"server"` - - Forward string `json:"forward"` } func (c *Config) Validate() error { @@ -42,16 +40,6 @@ func (c *Config) RegisterFlags(fs *pflag.FlagSet) { "http://localhost:8002", ` Piko server URL. This URL should point to the server admin port. -`, - ) - - fs.StringVar( - &c.Forward, - "forward", - "", - ` -Node ID to forward the request to. This can be useful when all nodes are behind -a load balancer and you want to inspect the status of a particular node. `, ) } diff --git a/server/server/middleware/auth.go b/server/upstream/auth.go similarity index 99% rename from server/server/middleware/auth.go rename to server/upstream/auth.go index f1e217e..d2f200b 100644 --- a/server/server/middleware/auth.go +++ b/server/upstream/auth.go @@ -1,4 +1,4 @@ -package middleware +package upstream import ( "errors" diff --git a/server/server/middleware/auth_test.go b/server/upstream/auth_test.go similarity index 99% rename from server/server/middleware/auth_test.go rename to server/upstream/auth_test.go index 18dff9e..e04e80f 100644 --- a/server/server/middleware/auth_test.go +++ b/server/upstream/auth_test.go @@ -1,4 +1,4 @@ -package middleware +package upstream import ( "encoding/json" diff --git a/server/upstream/manager.go b/server/upstream/manager.go new file mode 100644 index 0000000..0e0beac --- /dev/null +++ b/server/upstream/manager.go @@ -0,0 +1,150 @@ +package upstream + +import ( + "sync" + + "github.com/andydunstall/piko/server/cluster" + "github.com/prometheus/client_golang/prometheus" +) + +// Manager manages the upstream routes for each endpoint. +// +// This includes upstreams connected to the local node, or other server nodes +// in the cluster with a connected upstream for the target endpoint. +type Manager interface { + // Select looks up an upstream for the given endpoint ID. + // + // This will first look for an upstream connected to the local node, and + // load balance among the available connected upstreams. + // + // If there are no upstreams connected for the endpoint, and 'allowForward' + // is true, it will look for another node in the cluster that has an + // upstream connection for the endpoint and use that node as the upstream. + Select(endpointID string, allowForward bool) (Upstream, bool) + + // AddConn adds a local upstream connection. + AddConn(u Upstream) + + // RemoveConn removes a local upstream connection. + RemoveConn(u Upstream) +} + +// loadBalancer load balances requests among upstreams in a round-robin +// fashion. 
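+// loadBalancer is not safe for concurrent use; LoadBalancedManager guards
+// access to it with its own mutex.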
+type loadBalancer struct { + upstreams []Upstream + nextIndex int +} + +func (lb *loadBalancer) Add(u Upstream) { + lb.upstreams = append(lb.upstreams, u) +} + +func (lb *loadBalancer) Remove(u Upstream) bool { + for i := 0; i != len(lb.upstreams); i++ { + if lb.upstreams[i] != u { + continue + } + lb.upstreams = append(lb.upstreams[:i], lb.upstreams[i+1:]...) + if len(lb.upstreams) == 0 { + return true + } + lb.nextIndex %= len(lb.upstreams) + return false + } + return len(lb.upstreams) == 0 +} + +func (lb *loadBalancer) Next() Upstream { + if len(lb.upstreams) == 0 { + return nil + } + + u := lb.upstreams[lb.nextIndex] + lb.nextIndex++ + lb.nextIndex %= len(lb.upstreams) + return u +} + +type LoadBalancedManager struct { + localUpstreams map[string]*loadBalancer + + mu sync.Mutex + + cluster *cluster.State + + metrics *Metrics +} + +func NewLoadBalancedManager(cluster *cluster.State) *LoadBalancedManager { + return &LoadBalancedManager{ + localUpstreams: make(map[string]*loadBalancer), + cluster: cluster, + metrics: NewMetrics(), + } +} + +func (m *LoadBalancedManager) Select(endpointID string, allowRemote bool) (Upstream, bool) { + m.mu.Lock() + defer m.mu.Unlock() + + lb, ok := m.localUpstreams[endpointID] + if ok { + m.metrics.UpstreamRequestsTotal.Inc() + return lb.Next(), true + } + if !allowRemote { + return nil, false + } + + node, ok := m.cluster.LookupEndpoint(endpointID) + if !ok { + return nil, false + } + m.metrics.RemoteRequestsTotal.With(prometheus.Labels{ + "node_id": node.ID, + }).Inc() + return NewNodeUpstream(endpointID, node), true +} + +func (m *LoadBalancedManager) AddConn(u Upstream) { + m.mu.Lock() + defer m.mu.Unlock() + + lb, ok := m.localUpstreams[u.EndpointID()] + if !ok { + lb = &loadBalancer{} + + m.metrics.RegisteredEndpoints.Inc() + } + + lb.Add(u) + m.localUpstreams[u.EndpointID()] = lb + + m.cluster.AddLocalEndpoint(u.EndpointID()) + + m.metrics.ConnectedUpstreams.Inc() +} + +func (m *LoadBalancedManager) RemoveConn(u Upstream) { + m.mu.Lock() + defer m.mu.Unlock() + + lb, ok := m.localUpstreams[u.EndpointID()] + if !ok { + return + } + if lb.Remove(u) { + delete(m.localUpstreams, u.EndpointID()) + + m.metrics.RegisteredEndpoints.Dec() + } + + m.cluster.RemoveLocalEndpoint(u.EndpointID()) + + m.metrics.ConnectedUpstreams.Dec() +} + +func (m *LoadBalancedManager) Metrics() *Metrics { + return m.metrics +} diff --git a/serverv2/upstream/manager_test.go b/server/upstream/manager_test.go similarity index 97% rename from serverv2/upstream/manager_test.go rename to server/upstream/manager_test.go index 0957479..a7977a2 100644 --- a/serverv2/upstream/manager_test.go +++ b/server/upstream/manager_test.go @@ -20,7 +20,7 @@ func (u *fakeUpstream) Dial() (net.Conn, error) { } func TestLocalLoadBalancer(t *testing.T) { - lb := &localLoadBalancer{} + lb := &loadBalancer{} assert.Nil(t, lb.Next()) diff --git a/server/proxy/metrics.go b/server/upstream/metrics.go similarity index 56% rename from server/proxy/metrics.go rename to server/upstream/metrics.go index 868db74..5912043 100644 --- a/server/proxy/metrics.go +++ b/server/upstream/metrics.go @@ -1,4 +1,4 @@ -package proxy +package upstream import "github.com/prometheus/client_golang/prometheus" @@ -9,13 +9,13 @@ type Metrics struct { // RegisteredEndpoints is the number of endpoints registered to this node. RegisteredEndpoints prometheus.Gauge - // ForwardedLocalTotal is the number of requests forwarded to an upstream - // connected to the local node. 
- ForwardedLocalTotal prometheus.Counter + // UpstreamRequestsTotal is the number of requests sent to an + // upstream connected to the local node. + UpstreamRequestsTotal prometheus.Counter - // ForwardedRemoteTotal is the number of requests forwarded to a remote - // node. Labelled by target node ID. - ForwardedRemoteTotal *prometheus.CounterVec + // RemoteRequestsTotal is the number of requests sent to another node. + // Labelled by target node ID. + RemoteRequestsTotal *prometheus.CounterVec } func NewMetrics() *Metrics { @@ -23,7 +23,7 @@ func NewMetrics() *Metrics { ConnectedUpstreams: prometheus.NewGauge( prometheus.GaugeOpts{ Namespace: "piko", - Subsystem: "proxy", + Subsystem: "upstreams", Name: "connected_upstreams", Help: "Number of upstreams connected to this node", }, @@ -31,25 +31,25 @@ func NewMetrics() *Metrics { RegisteredEndpoints: prometheus.NewGauge( prometheus.GaugeOpts{ Namespace: "piko", - Subsystem: "proxy", + Subsystem: "upstreams", Name: "registered_endpoints", Help: "Number of endpoints registered to this node", }, ), - ForwardedLocalTotal: prometheus.NewCounter( + UpstreamRequestsTotal: prometheus.NewCounter( prometheus.CounterOpts{ Namespace: "piko", - Subsystem: "proxy", - Name: "forwarded_local_total", - Help: "Number of requests forwarded to an upstream connected to the local node", + Subsystem: "upstreams", + Name: "upstream_requests_total", + Help: "Number of requests sent to an upstream connected to the local node", }, ), - ForwardedRemoteTotal: prometheus.NewCounterVec( + RemoteRequestsTotal: prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: "piko", - Subsystem: "proxy", - Name: "forwarded_remote_total", - Help: "Number of requests forwarded to a remote node", + Subsystem: "upstreams", + Name: "remote_requests_total", + Help: "Number of requests sent to a remote node", }, []string{"node_id"}, ), @@ -60,7 +60,7 @@ func (m *Metrics) Register(registry *prometheus.Registry) { registry.MustRegister( m.ConnectedUpstreams, m.RegisteredEndpoints, - m.ForwardedLocalTotal, - m.ForwardedRemoteTotal, + m.UpstreamRequestsTotal, + m.RemoteRequestsTotal, ) } diff --git a/server/server/upstream/server.go b/server/upstream/server.go similarity index 54% rename from server/server/upstream/server.go rename to server/upstream/server.go index d8ceae4..84e3ef5 100644 --- a/server/server/upstream/server.go +++ b/server/upstream/server.go @@ -1,4 +1,4 @@ -package server +package upstream import ( "context" @@ -6,95 +6,72 @@ import ( "fmt" "net" "net/http" - "time" - pikowebsocket "github.com/andydunstall/piko/pkg/conn/websocket" "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/pkg/rpc" + "github.com/andydunstall/piko/pkg/mux" + pikowebsocket "github.com/andydunstall/piko/pkg/websocket" "github.com/andydunstall/piko/server/auth" - proxy "github.com/andydunstall/piko/server/proxy" - "github.com/andydunstall/piko/server/server/middleware" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) -type Proxy interface { - AddConn(conn proxy.Conn) - RemoveConn(conn proxy.Conn) -} - -// Server is the HTTP server upstream listeners to register endpoints. +// Server accepts connections from upstream services. 
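+//
+// Upstream services register by opening a WebSocket connection to
+// /piko/v1/upstream/:endpointID, which is multiplexed and added to the
+// upstream Manager for that endpoint.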
type Server struct { - ln net.Listener - - router *gin.Engine + upstreams Manager httpServer *http.Server - rpcServer *rpcServer websocketUpgrader *websocket.Upgrader - proxy Proxy - - shutdownCtx context.Context - shutdownCancel func() - logger log.Logger } func NewServer( - ln net.Listener, - proxy Proxy, + upstreams Manager, verifier auth.Verifier, tlsConfig *tls.Config, logger log.Logger, ) *Server { - logger = logger.WithSubsystem("upstream.server") - - shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) + logger = logger.WithSubsystem("admin") router := gin.New() server := &Server{ - ln: ln, - router: router, + upstreams: upstreams, httpServer: &http.Server{ - Addr: ln.Addr().String(), Handler: router, TLSConfig: tlsConfig, ErrorLog: logger.StdLogger(zapcore.WarnLevel), }, - rpcServer: newRPCServer(), websocketUpgrader: &websocket.Upgrader{}, - shutdownCtx: shutdownCtx, - shutdownCancel: shutdownCancel, - proxy: proxy, logger: logger, } - if verifier != nil { - tokenMiddleware := middleware.NewAuthMiddleware(verifier, logger) - router.Use(tokenMiddleware.VerifyEndpointToken) - } - // Recover from panics. - server.router.Use(gin.CustomRecoveryWithWriter(nil, server.panicRoute)) + router.Use(gin.CustomRecoveryWithWriter(nil, server.panicRoute)) - server.router.Use(middleware.NewLogger(logger)) + if verifier != nil { + authMiddleware := NewAuthMiddleware(verifier, logger) + router.Use(authMiddleware.VerifyEndpointToken) + } - server.registerRoutes() + server.registerRoutes(router) return server } -func (s *Server) Serve() error { - s.logger.Info("starting http server", zap.String("addr", s.ln.Addr().String())) +func (s *Server) Serve(ln net.Listener) error { + s.logger.Info( + "starting upstream server", + zap.String("addr", ln.Addr().String()), + ) + var err error if s.httpServer.TLSConfig != nil { - err = s.httpServer.ServeTLS(s.ln, "", "") + err = s.httpServer.ServeTLS(ln, "", "") } else { - err = s.httpServer.Serve(s.ln) + err = s.httpServer.Serve(ln) } if err != nil && err != http.ErrServerClosed { @@ -103,20 +80,17 @@ func (s *Server) Serve() error { return nil } +// Shutdown attempts to gracefully shutdown the server by waiting for pending +// requests to complete. func (s *Server) Shutdown(ctx context.Context) error { return s.httpServer.Shutdown(ctx) } -func (s *Server) registerRoutes() { - piko := s.router.Group("/piko/v1") - piko.GET("/listener/:endpointID", s.listenerRoute) -} - -// listenerRoute handles WebSocket connections from upstream listeners. -func (s *Server) listenerRoute(c *gin.Context) { +// upstreamRoute handles WebSocket connections from upstream services. 
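+//
+// The connection is upgraded to a WebSocket, wrapped in a multiplexed mux
+// session and registered with the Manager until the session closes or the
+// endpoint token expires.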
+func (s *Server) upstreamRoute(c *gin.Context) { endpointID := c.Param("endpointID") - token, ok := c.Get(middleware.TokenContextKey) + token, ok := c.Get(TokenContextKey) if ok { endpointToken := token.(*auth.EndpointToken) if !endpointToken.EndpointPermitted(endpointID) { @@ -139,24 +113,21 @@ func (s *Server) listenerRoute(c *gin.Context) { s.logger.Warn("failed to upgrade websocket", zap.Error(err)) return } - stream := rpc.NewStream( - pikowebsocket.NewConn(wsConn), - s.rpcServer.Handler(), - s.logger, - ) - defer stream.Close() + conn := pikowebsocket.New(wsConn) + defer conn.Close() s.logger.Debug( - "listener connected", + "upstream connected", + zap.String("endpoint-id", endpointID), + zap.String("client-ip", c.ClientIP()), + ) + defer s.logger.Debug( + "upstream disconnected", zap.String("endpoint-id", endpointID), zap.String("client-ip", c.ClientIP()), ) - conn := proxy.NewRPCConn(endpointID, stream) - s.proxy.AddConn(conn) - defer s.proxy.RemoveConn(conn) - - ctx := s.shutdownCtx + ctx := context.Background() if ok { // If the token has an expiry, then we ensure we close the connection // to the endpoint once the token expires. @@ -168,15 +139,32 @@ func (s *Server) listenerRoute(c *gin.Context) { } } - if err := stream.Monitor( - ctx, - time.Second*10, - time.Second*10, - ); err != nil { - s.logger.Debug("listener disconnected", zap.Error(err)) + sess := mux.OpenServer(conn) + upstream := NewConnUpstream(endpointID, sess) + + s.upstreams.AddConn(upstream) + defer s.upstreams.RemoveConn(upstream) + + closedCh := make(chan struct{}) + go func() { + if err := sess.Wait(); err != nil { + s.logger.Warn("session closed", zap.Error(err)) + } + close(closedCh) + }() + + select { + case <-ctx.Done(): + s.logger.Warn("token expired") + case <-closedCh: } } +func (s *Server) registerRoutes(router *gin.Engine) { + piko := router.Group("/piko/v1") + piko.GET("/upstream/:endpointID", s.upstreamRoute) +} + func (s *Server) panicRoute(c *gin.Context, err any) { s.logger.Error( "handler panic", diff --git a/server/upstream/server_test.go b/server/upstream/server_test.go new file mode 100644 index 0000000..fda34a9 --- /dev/null +++ b/server/upstream/server_test.go @@ -0,0 +1,267 @@ +package upstream + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "testing" + "time" + + "github.com/andydunstall/piko/pkg/log" + "github.com/andydunstall/piko/pkg/testutil" + "github.com/andydunstall/piko/pkg/websocket" + "github.com/andydunstall/piko/server/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type fakeManager struct { + addConnCh chan Upstream + removeConnCh chan Upstream +} + +func newFakeManager() *fakeManager { + return &fakeManager{ + addConnCh: make(chan Upstream), + removeConnCh: make(chan Upstream), + } +} + +func (m *fakeManager) Select(_ string, _ bool) (Upstream, bool) { + return nil, false +} + +func (m *fakeManager) AddConn(u Upstream) { + m.addConnCh <- u +} + +func (m *fakeManager) RemoveConn(u Upstream) { + m.removeConnCh <- u +} + +func TestServer_Register(t *testing.T) { + t.Run("ok", func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + manager := newFakeManager() + + s := NewServer(manager, nil, nil, log.NewNopLogger()) + go func() { + require.NoError(t, s.Serve(ln)) + }() + defer s.Shutdown(context.TODO()) + + url := fmt.Sprintf( + "ws://%s/piko/v1/upstream/my-endpoint", + ln.Addr().String(), + ) + conn, err := websocket.Dial(context.TODO(), url) + require.NoError(t, err) 
+ + addedUpstream := <-manager.addConnCh + assert.Equal(t, "my-endpoint", addedUpstream.EndpointID()) + + conn.Close() + + removedUpstream := <-manager.removeConnCh + assert.Equal(t, "my-endpoint", removedUpstream.EndpointID()) + }) +} + +func TestServer_Authentication(t *testing.T) { + t.Run("ok", func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + manager := newFakeManager() + + verifier := &fakeVerifier{ + handler: func(token string) (auth.EndpointToken, error) { + assert.Equal(t, "123", token) + return auth.EndpointToken{ + Expiry: time.Now().Add(time.Hour), + Endpoints: []string{"my-endpoint"}, + }, nil + }, + } + + s := NewServer(manager, verifier, nil, log.NewNopLogger()) + go func() { + require.NoError(t, s.Serve(ln)) + }() + defer s.Shutdown(context.TODO()) + + url := fmt.Sprintf( + "ws://%s/piko/v1/upstream/my-endpoint", + ln.Addr().String(), + ) + conn, err := websocket.Dial(context.TODO(), url, websocket.WithToken("123")) + require.NoError(t, err) + + addedUpstream := <-manager.addConnCh + assert.Equal(t, "my-endpoint", addedUpstream.EndpointID()) + + conn.Close() + + removedUpstream := <-manager.removeConnCh + assert.Equal(t, "my-endpoint", removedUpstream.EndpointID()) + }) + + t.Run("token expires", func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + manager := newFakeManager() + + verifier := &fakeVerifier{ + handler: func(token string) (auth.EndpointToken, error) { + assert.Equal(t, "123", token) + return auth.EndpointToken{ + // Set a short expiry as we wait for the token to expire. + Expiry: time.Now().Add(time.Millisecond * 10), + Endpoints: []string{"my-endpoint"}, + }, nil + }, + } + + s := NewServer(manager, verifier, nil, log.NewNopLogger()) + go func() { + require.NoError(t, s.Serve(ln)) + }() + defer s.Shutdown(context.TODO()) + + url := fmt.Sprintf( + "ws://%s/piko/v1/upstream/my-endpoint", + ln.Addr().String(), + ) + conn, err := websocket.Dial(context.TODO(), url, websocket.WithToken("123")) + require.NoError(t, err) + defer conn.Close() + + addedUpstream := <-manager.addConnCh + assert.Equal(t, "my-endpoint", addedUpstream.EndpointID()) + + // Token should expire without closing client. 
+ + removedUpstream := <-manager.removeConnCh + assert.Equal(t, "my-endpoint", removedUpstream.EndpointID()) + }) + + t.Run("endpoint not permitted", func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + manager := newFakeManager() + + verifier := &fakeVerifier{ + handler: func(token string) (auth.EndpointToken, error) { + assert.Equal(t, "123", token) + return auth.EndpointToken{ + Expiry: time.Now().Add(time.Hour), + Endpoints: []string{"foo"}, + }, nil + }, + } + + s := NewServer(manager, verifier, nil, log.NewNopLogger()) + go func() { + require.NoError(t, s.Serve(ln)) + }() + defer s.Shutdown(context.TODO()) + + url := fmt.Sprintf( + "ws://%s/piko/v1/upstream/my-endpoint", + ln.Addr().String(), + ) + _, err = websocket.Dial(context.TODO(), url, websocket.WithToken("123")) + require.ErrorContains(t, err, "401: websocket: bad handshake") + }) + + t.Run("unauthenticated", func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + manager := newFakeManager() + + verifier := &fakeVerifier{ + handler: func(token string) (auth.EndpointToken, error) { + assert.Equal(t, "123", token) + return auth.EndpointToken{}, auth.ErrInvalidToken + }, + } + + s := NewServer(manager, verifier, nil, log.NewNopLogger()) + go func() { + require.NoError(t, s.Serve(ln)) + }() + defer s.Shutdown(context.TODO()) + + url := fmt.Sprintf( + "ws://%s/piko/v1/upstream/my-endpoint", + ln.Addr().String(), + ) + _, err = websocket.Dial(context.TODO(), url, websocket.WithToken("123")) + require.ErrorContains(t, err, "401: websocket: bad handshake") + }) +} + +func TestServer_TLS(t *testing.T) { + rootCAPool, cert, err := testutil.LocalTLSServerCert() + require.NoError(t, err) + + tlsConfig := &tls.Config{} + tlsConfig.Certificates = []tls.Certificate{cert} + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + manager := newFakeManager() + + s := NewServer(manager, nil, tlsConfig, log.NewNopLogger()) + go func() { + require.NoError(t, s.Serve(ln)) + }() + defer s.Shutdown(context.TODO()) + + t.Run("wss ok", func(t *testing.T) { + url := fmt.Sprintf( + "wss://%s/piko/v1/upstream/my-endpoint", + ln.Addr().String(), + ) + clientTLSConfig := &tls.Config{ + RootCAs: rootCAPool, + } + conn, err := websocket.Dial( + context.TODO(), url, websocket.WithTLSConfig(clientTLSConfig), + ) + require.NoError(t, err) + + addedUpstream := <-manager.addConnCh + assert.Equal(t, "my-endpoint", addedUpstream.EndpointID()) + + conn.Close() + + removedUpstream := <-manager.removeConnCh + assert.Equal(t, "my-endpoint", removedUpstream.EndpointID()) + }) + + t.Run("wss bad ca", func(t *testing.T) { + url := fmt.Sprintf( + "wss://%s/piko/v1/upstream/my-endpoint", + ln.Addr().String(), + ) + _, err := websocket.Dial(context.TODO(), url) + require.ErrorContains(t, err, "certificate signed by unknown authority") + }) + + t.Run("ws", func(t *testing.T) { + url := fmt.Sprintf( + "ws://%s/piko/v1/upstream/my-endpoint", + ln.Addr().String(), + ) + _, err := websocket.Dial(context.TODO(), url) + require.ErrorContains(t, err, "bad handshake") + }) +} diff --git a/server/upstream/upstream.go b/server/upstream/upstream.go new file mode 100644 index 0000000..801eaf0 --- /dev/null +++ b/server/upstream/upstream.go @@ -0,0 +1,60 @@ +package upstream + +import ( + "net" + + "github.com/andydunstall/piko/pkg/mux" + "github.com/andydunstall/piko/server/cluster" +) + +// Upstream represents an upstream for a given endpoint. 
+// +// An upstream may be an upstream service connected to the local node, or +// another Piko server node. +type Upstream interface { + EndpointID() string + Dial() (net.Conn, error) +} + +// ConnUpstream represents a connection to an upstream service thats connected +// to the local node. +type ConnUpstream struct { + endpointID string + sess *mux.Session +} + +func NewConnUpstream(endpointID string, sess *mux.Session) *ConnUpstream { + return &ConnUpstream{ + endpointID: endpointID, + sess: sess, + } +} + +func (u *ConnUpstream) EndpointID() string { + return u.endpointID +} + +func (u *ConnUpstream) Dial() (net.Conn, error) { + return u.sess.Dial() +} + +// NodeUpstream represents a remote Piko server node. +type NodeUpstream struct { + endpointID string + node *cluster.Node +} + +func NewNodeUpstream(endpointID string, node *cluster.Node) *NodeUpstream { + return &NodeUpstream{ + endpointID: endpointID, + node: node, + } +} + +func (u *NodeUpstream) EndpointID() string { + return u.endpointID +} + +func (u *NodeUpstream) Dial() (net.Conn, error) { + return net.Dial("tcp", u.node.ProxyAddr) +} diff --git a/server/usage/reporter.go b/server/usage/reporter.go deleted file mode 100644 index 7a5dea5..0000000 --- a/server/usage/reporter.go +++ /dev/null @@ -1,105 +0,0 @@ -package usage - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "runtime" - "time" - - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/server/proxy" - "github.com/google/uuid" - "go.uber.org/zap" -) - -const ( - reportInterval = time.Hour -) - -type Report struct { - ID string `json:"id"` - OS string `json:"os"` - Arch string `json:"arch"` - Uptime int64 `json:"uptime"` - Requests uint64 `json:"requests"` - Upstreams uint64 `json:"upstreams"` -} - -// Reporter sends a periodic usage report. -type Reporter struct { - id string - start time.Time - proxy *proxy.Proxy - logger log.Logger -} - -func NewReporter(proxy *proxy.Proxy, logger log.Logger) *Reporter { - return &Reporter{ - id: uuid.New().String(), - start: time.Now(), - proxy: proxy, - logger: logger.WithSubsystem("reporter"), - } -} - -func (r *Reporter) Run(ctx context.Context) { - // Report on startup. - r.report() - - ticker := time.NewTicker(reportInterval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - // Report on shutdown. - r.report() - return - case <-ticker.C: - // Report on interval. - r.report() - } - } -} - -func (r *Reporter) report() { - report := &Report{ - ID: r.id, - OS: runtime.GOOS, - Arch: runtime.GOARCH, - Uptime: int64(time.Since(r.start).Seconds()), - Requests: r.proxy.Usage().Requests.Load(), - Upstreams: r.proxy.Usage().Upstreams.Load(), - } - if err := r.send(report); err != nil { - // Debug only as theres no user impact. 
- r.logger.Debug("failed to send usage report", zap.Error(err)) - } -} - -func (r *Reporter) send(report *Report) error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - body, err := json.Marshal(report) - if err != nil { - return fmt.Errorf("marshal: %w", err) - } - req, err := http.NewRequestWithContext( - ctx, http.MethodPost, "http://report.pikoproxy.com/v1", bytes.NewBuffer(body), - ) - if err != nil { - return fmt.Errorf("request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) - if err != nil { - return fmt.Errorf("request: %w", err) - } - defer resp.Body.Close() - - return nil -} diff --git a/serverv2/reverseproxy/logger.go b/serverv2/reverseproxy/logger.go deleted file mode 100644 index 251c4b2..0000000 --- a/serverv2/reverseproxy/logger.go +++ /dev/null @@ -1,49 +0,0 @@ -package reverseproxy - -import ( - "net/http" - "time" - - "github.com/andydunstall/piko/pkg/log" - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -type loggedRequest struct { - Proto string `json:"proto"` - Method string `json:"method"` - Host string `json:"host"` - Path string `json:"path"` - RequestHeaders http.Header `json:"request_headers"` - ResponseHeaders http.Header `json:"response_headers"` - Status int `json:"status"` - Duration string `json:"duration"` -} - -// NewLoggerMiddleware creates logging middleware that logs every request. -func NewLoggerMiddleware(accessLog bool, logger log.Logger) gin.HandlerFunc { - logger = logger.WithSubsystem("reverseproxy.access") - return func(c *gin.Context) { - s := time.Now() - - c.Next() - - req := &loggedRequest{ - Proto: c.Request.Proto, - Method: c.Request.Method, - Host: c.Request.Method, - Path: c.Request.URL.Path, - RequestHeaders: c.Request.Header, - ResponseHeaders: c.Writer.Header(), - Status: c.Writer.Status(), - Duration: time.Since(s).String(), - } - if c.Writer.Status() > http.StatusInternalServerError { - logger.Warn("request", zap.Any("request", req)) - } else if accessLog { - logger.Info("request", zap.Any("request", req)) - } else { - logger.Debug("request", zap.Any("request", req)) - } - } -} diff --git a/serverv2/reverseproxy/reverseproxy.go b/serverv2/reverseproxy/reverseproxy.go deleted file mode 100644 index f686ac8..0000000 --- a/serverv2/reverseproxy/reverseproxy.go +++ /dev/null @@ -1,179 +0,0 @@ -package reverseproxy - -import ( - "context" - "encoding/json" - "io" - "net" - "net/http" - "strings" - - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/serverv2/upstream" - "go.uber.org/zap" -) - -type contextKey int - -const ( - upstreamContextKey contextKey = iota -) - -type UpstreamManager interface { - Select(endpointID string, allowForward bool) (upstream.Upstream, bool) -} - -type ReverseProxy struct { - upstreams UpstreamManager - - // upstreamTransport is the transport to forward requests to upstream - // connections. - upstreamTransport *http.Transport - - logger log.Logger -} - -func NewReverseProxy(upstreams UpstreamManager, logger log.Logger) *ReverseProxy { - proxy := &ReverseProxy{ - upstreams: upstreams, - logger: logger, - } - - proxy.upstreamTransport = &http.Transport{ - DialContext: proxy.dialUpstream, - // 'connections' to the upstream are multiplexed over a single TCP - // connection so theres no overhead to creating new connections, - // therefore it doesn't make sense to keep them alive. 
- DisableKeepAlives: true, - } - - return proxy -} - -func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Whether the request was forwarded from another Piko node. - forwarded := r.Header.Get("x-piko-forward") == "true" - - logger := p.logger.With( - zap.String("method", r.Method), - zap.String("host", r.Host), - zap.String("path", r.URL.Path), - zap.Bool("forwarded", forwarded), - ) - - // TODO(andydunstall): Add a timeout to ctx. - - endpointID := EndpointIDFromRequest(r) - if endpointID == "" { - logger.Warn("request missing endpoint id") - - if err := errorResponse( - w, http.StatusBadRequest, "missing endpoint id", - ); err != nil { - p.logger.Warn("failed to write error response", zap.Error(err)) - } - return - } - - logger = logger.With(zap.String("endpoint-id", endpointID)) - - r.Header.Add("x-piko-forward", "true") - - // If there is a connected upstream, attempt to forward the request to one - // of those upstreams. Note this includes remote nodes that are reporting - // they have an available upstream. We don't allow multiple hops, so if - // forwarded is true we only select from local nodes. - upstream, ok := p.upstreams.Select(endpointID, !forwarded) - if ok { - p.reverseProxyUpstream(w, r, upstream, logger) - return - } - - if err := errorResponse( - w, http.StatusBadGateway, "no available upstreams", - ); err != nil { - p.logger.Warn("failed to write error response", zap.Error(err)) - } -} - -func (p *ReverseProxy) reverseProxyUpstream( - w http.ResponseWriter, - r *http.Request, - upstream upstream.Upstream, - logger log.Logger, -) { - r.URL.Scheme = "http" - r.URL.Host = upstream.EndpointID() - - // Add the upstream to the context to pass to 'DialContext'. - ctx := context.WithValue(r.Context(), upstreamContextKey, upstream) - r = r.WithContext(ctx) - - resp, err := p.upstreamTransport.RoundTrip(r) - if err != nil { - logger.Warn("upstream unreachable", zap.Error(err)) - // TODO(andydunstall): Handle different error types. - if err := errorResponse( - w, http.StatusBadGateway, "upstream unreachable", - ); err != nil { - p.logger.Warn("failed to write error response", zap.Error(err)) - } - return - } - - // Write the response status, headers and body. - for k, v := range resp.Header { - w.Header()[k] = v - } - w.WriteHeader(resp.StatusCode) - if _, err := io.Copy(w, resp.Body); err != nil { - logger.Warn("failed to write response", zap.Error(err)) - return - } -} - -func (p *ReverseProxy) dialUpstream(ctx context.Context, _, _ string) (net.Conn, error) { - // As a bit of a hack to work with http.Transport, we add the upstream - // to the dial context. - upstream := ctx.Value(upstreamContextKey).(upstream.Upstream) - return upstream.Dial() -} - -type errorMessage struct { - Error string `json:"error"` -} - -func errorResponse(w http.ResponseWriter, statusCode int, message string) error { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Content-Type-Options", "nosniff") - w.WriteHeader(statusCode) - - m := &errorMessage{ - Error: message, - } - return json.NewEncoder(w).Encode(m) -} - -// EndpointIDFromRequest returns the endpoint ID from the HTTP request, or an -// empty string if no endpoint ID is specified. -// -// This will check both the 'x-piko-endpoint' header and 'Host' header, where -// x-piko-endpoint takes precedence. 
-func EndpointIDFromRequest(r *http.Request) string { - endpointID := r.Header.Get("x-piko-endpoint") - if endpointID != "" { - return endpointID - } - - host := r.Host - if host != "" && strings.Contains(host, ".") { - // If a host is given and contains a separator, use the bottom-level - // domain as the endpoint ID. - // - // Such as if the domain is 'xyz.piko.example.com', then 'xyz' is the - // endpoint ID. - return strings.Split(host, ".")[0] - } - - return "" -} diff --git a/serverv2/reverseproxy/reverseproxy_test.go b/serverv2/reverseproxy/reverseproxy_test.go deleted file mode 100644 index 3097ff3..0000000 --- a/serverv2/reverseproxy/reverseproxy_test.go +++ /dev/null @@ -1,150 +0,0 @@ -package reverseproxy - -import ( - "bytes" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/serverv2/upstream" - "github.com/stretchr/testify/assert" -) - -type fakeManager struct { - handler func(endpointID string, allowForward bool) (upstream.Upstream, bool) -} - -func (m *fakeManager) Select( - endpointID string, - allowForward bool, -) (upstream.Upstream, bool) { - return m.handler(endpointID, allowForward) -} - -func TestReverseProxy(t *testing.T) { - // Tests forwarding a request to the upstream with a path, query and body, - // then checking the response is forwarded correctly. - t.Run("ok", func(t *testing.T) { - upstreamServer := httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/foo/bar", r.URL.Path) - assert.Equal(t, "a=b", r.URL.RawQuery) - - buf := new(strings.Builder) - // nolint - io.Copy(buf, r.Body) - assert.Equal(t, "foo", buf.String()) - - // nolint - w.Write([]byte("bar")) - }, - )) - defer upstreamServer.Close() - - upstreamClient := upstream.NewTCPUpstream( - "my-endpoint", upstreamServer.Listener.Addr().String(), - ) - - proxy := NewReverseProxy(&fakeManager{ - handler: func(endpointID string, allowForward bool) (upstream.Upstream, bool) { - assert.Equal(t, "my-endpoint", endpointID) - assert.True(t, allowForward) - return upstreamClient, true - }, - }, log.NewNopLogger()) - - b := bytes.NewReader([]byte("foo")) - r := httptest.NewRequest(http.MethodGet, "/foo/bar?a=b", b) - r.Header.Add("x-piko-endpoint", "my-endpoint") - - w := httptest.NewRecorder() - proxy.ServeHTTP(w, r) - - resp := w.Result() - defer resp.Body.Close() - - assert.Equal(t, http.StatusOK, resp.StatusCode) - - buf := new(strings.Builder) - // nolint - io.Copy(buf, resp.Body) - assert.Equal(t, "bar", buf.String()) - }) - - t.Run("no available upstreams", func(t *testing.T) { - proxy := NewReverseProxy(&fakeManager{ - handler: func(endpointID string, allowForward bool) (upstream.Upstream, bool) { - assert.Equal(t, "my-endpoint", endpointID) - assert.True(t, allowForward) - return nil, false - }, - }, log.NewNopLogger()) - - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.Header.Add("x-piko-endpoint", "my-endpoint") - - w := httptest.NewRecorder() - proxy.ServeHTTP(w, r) - - resp := w.Result() - defer resp.Body.Close() - - assert.Equal(t, http.StatusBadGateway, resp.StatusCode) - - m := errorMessage{} - assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) - assert.Equal(t, "no available upstreams", m.Error) - }) - - t.Run("missing endpoint id", func(t *testing.T) { - proxy := NewReverseProxy(nil, log.NewNopLogger()) - - r := httptest.NewRequest(http.MethodGet, "/", nil) - // The host must have a '.' 
separator to be parsed as an endpoint ID. - r.Host = "foo" - - w := httptest.NewRecorder() - proxy.ServeHTTP(w, r) - - resp := w.Result() - defer resp.Body.Close() - - assert.Equal(t, http.StatusBadRequest, resp.StatusCode) - - m := errorMessage{} - assert.NoError(t, json.NewDecoder(resp.Body).Decode(&m)) - assert.Equal(t, "missing endpoint id", m.Error) - }) -} - -func TestEndpointIDFromRequest(t *testing.T) { - t.Run("host header", func(t *testing.T) { - endpointID := EndpointIDFromRequest(&http.Request{ - Host: "my-endpoint.piko.com:9000", - }) - assert.Equal(t, "my-endpoint", endpointID) - }) - - t.Run("x-piko-endpoint header", func(t *testing.T) { - header := make(http.Header) - header.Add("x-piko-endpoint", "my-endpoint") - endpointID := EndpointIDFromRequest(&http.Request{ - // Even though the host header is provided, 'x-piko-endpoint' - // takes precedence. - Host: "another-endpoint.piko.com:9000", - Header: header, - }) - assert.Equal(t, "my-endpoint", endpointID) - }) - - t.Run("no endpoint", func(t *testing.T) { - endpointID := EndpointIDFromRequest(&http.Request{ - Host: "localhost:9000", - }) - assert.Equal(t, "", endpointID) - }) -} diff --git a/serverv2/reverseproxy/upstream.go b/serverv2/reverseproxy/upstream.go deleted file mode 100644 index bc91aeb..0000000 --- a/serverv2/reverseproxy/upstream.go +++ /dev/null @@ -1,10 +0,0 @@ -package reverseproxy - -import ( - "net" -) - -// UpstreamPool contains the connected upstreams. -type UpstreamPool interface { - Dial(endpointID string) (net.Conn, error) -} diff --git a/serverv2/upstream/manager.go b/serverv2/upstream/manager.go deleted file mode 100644 index d656891..0000000 --- a/serverv2/upstream/manager.go +++ /dev/null @@ -1,95 +0,0 @@ -package upstream - -import ( - "sync" -) - -// localLoadBalancer load balances requests among upstreams connected to -// the local node. -type localLoadBalancer struct { - upstreams []Upstream - nextIndex int -} - -func (lb *localLoadBalancer) Add(u Upstream) { - lb.upstreams = append(lb.upstreams, u) -} - -func (lb *localLoadBalancer) Remove(u Upstream) bool { - for i := 0; i != len(lb.upstreams); i++ { - if lb.upstreams[i] != u { - continue - } - lb.upstreams = append(lb.upstreams[:i], lb.upstreams[i+1:]...) - if len(lb.upstreams) == 0 { - return true - } - lb.nextIndex %= len(lb.upstreams) - return false - } - return len(lb.upstreams) == 0 -} - -func (lb *localLoadBalancer) Next() Upstream { - if len(lb.upstreams) == 0 { - return nil - } - - u := lb.upstreams[lb.nextIndex] - lb.nextIndex++ - lb.nextIndex %= len(lb.upstreams) - return u -} - -// Manager manages the set of local upsteam services. 
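The Manager below keeps one round-robin load balancer per endpoint, so repeated Select calls for the same endpoint rotate through its registered upstreams, and the endpoint is dropped entirely once its last upstream is removed. A minimal sketch of that behavior against the removed serverv2/upstream package, using a purely illustrative stub upstream (labelledUpstream is not part of the package):

package main

import (
	"fmt"
	"net"

	"github.com/andydunstall/piko/serverv2/upstream"
)

// labelledUpstream is a stub used only to observe which upstream Select
// returns; it never dials anything.
type labelledUpstream struct {
	endpointID string
	label      string
}

func (u *labelledUpstream) EndpointID() string { return u.endpointID }

func (u *labelledUpstream) Dial() (net.Conn, error) {
	return nil, fmt.Errorf("stub upstream %q does not dial", u.label)
}

func main() {
	m := upstream.NewManager()

	a := &labelledUpstream{endpointID: "my-endpoint", label: "a"}
	b := &labelledUpstream{endpointID: "my-endpoint", label: "b"}
	m.Add(a)
	m.Add(b)

	// Select rotates round-robin through the upstreams registered for the
	// endpoint: a, b, a, b.
	for i := 0; i < 4; i++ {
		u, _ := m.Select("my-endpoint", false)
		fmt.Println(u.(*labelledUpstream).label)
	}

	// Once the last upstream for the endpoint is removed, Select reports
	// that no upstream is available.
	m.Remove(a)
	m.Remove(b)
	if _, ok := m.Select("my-endpoint", false); !ok {
		fmt.Println("no upstreams for my-endpoint")
	}
}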
-type Manager struct { - localLoadBalancers map[string]*localLoadBalancer - - mu sync.Mutex -} - -func NewManager() *Manager { - return &Manager{ - localLoadBalancers: make(map[string]*localLoadBalancer), - } -} - -func (m *Manager) Add(u Upstream) { - m.mu.Lock() - defer m.mu.Unlock() - - lb, ok := m.localLoadBalancers[u.EndpointID()] - if !ok { - lb = &localLoadBalancer{} - } - - lb.Add(u) - m.localLoadBalancers[u.EndpointID()] = lb -} - -func (m *Manager) Remove(u Upstream) { - m.mu.Lock() - defer m.mu.Unlock() - - lb, ok := m.localLoadBalancers[u.EndpointID()] - if !ok { - return - } - if lb.Remove(u) { - delete(m.localLoadBalancers, u.EndpointID()) - } -} - -func (m *Manager) Select(endpointID string, _ bool) (Upstream, bool) { - m.mu.Lock() - defer m.mu.Unlock() - - // TODO(andydunstall): If allowForward can select from another node, where - // Dial is just TCP - - lb, ok := m.localLoadBalancers[endpointID] - if !ok { - return nil, false - } - return lb.Next(), true -} diff --git a/serverv2/upstream/server.go b/serverv2/upstream/server.go deleted file mode 100644 index 3f068eb..0000000 --- a/serverv2/upstream/server.go +++ /dev/null @@ -1,144 +0,0 @@ -package upstream - -import ( - "context" - "crypto/tls" - "fmt" - "net" - "net/http" - "time" - - "github.com/andydunstall/piko/pkg/log" - pikowebsocket "github.com/andydunstall/piko/pkg/websocket" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" - "golang.ngrok.com/muxado/v2" -) - -// Server accepts connections from upstream services. -type Server struct { - manager *Manager - - router *gin.Engine - - httpServer *http.Server - - websocketUpgrader *websocket.Upgrader - - logger log.Logger -} - -func NewServer( - manager *Manager, - tlsConfig *tls.Config, - logger log.Logger, -) *Server { - router := gin.New() - server := &Server{ - manager: manager, - router: router, - httpServer: &http.Server{ - Handler: router, - TLSConfig: tlsConfig, - ErrorLog: logger.StdLogger(zapcore.WarnLevel), - }, - websocketUpgrader: &websocket.Upgrader{}, - logger: logger, - } - - // Recover from panics. - server.router.Use(gin.CustomRecoveryWithWriter(nil, server.panicRoute)) - - server.registerRoutes() - - return server -} - -func (s *Server) Serve(ln net.Listener) error { - s.logger.Info( - "starting http server", - zap.String("addr", ln.Addr().String()), - ) - var err error - if s.httpServer.TLSConfig != nil { - err = s.httpServer.ServeTLS(ln, "", "") - } else { - err = s.httpServer.Serve(ln) - } - - if err != nil && err != http.ErrServerClosed { - return fmt.Errorf("http serve: %w", err) - } - return nil -} - -func (s *Server) Shutdown(ctx context.Context) error { - return s.httpServer.Shutdown(ctx) -} - -func (s *Server) registerRoutes() { - piko := s.router.Group("/piko/v1") - piko.GET("/upstream/:endpointID", s.wsRoute) -} - -// listenerRoute handles WebSocket connections from upstream services. -func (s *Server) wsRoute(c *gin.Context) { - wsConn, err := s.websocketUpgrader.Upgrade(c.Writer, c.Request, nil) - if err != nil { - // Upgrade replies to the client so nothing else to do. 
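The wsRoute handler below registers each connected upstream with the Manager for the lifetime of its WebSocket connection, keeping the mux session alive with heartbeats. For context, a rough sketch of how the removed upstream Server is wired together (the bind address is a placeholder, and Shutdown would normally be driven by a signal handler):

package main

import (
	"net"

	"github.com/andydunstall/piko/pkg/log"
	"github.com/andydunstall/piko/serverv2/upstream"
)

func main() {
	manager := upstream.NewManager()

	// nil TLS config: accept plaintext WebSocket upgrades.
	server := upstream.NewServer(manager, nil, log.NewNopLogger())

	// Upstream services connect to
	// ws://<host>:8001/piko/v1/upstream/<endpoint-id>.
	ln, err := net.Listen("tcp", ":8001")
	if err != nil {
		panic(err)
	}

	// Serve blocks until the listener is closed or Shutdown is called, e.g.
	// server.Shutdown(context.Background()) from a signal handler.
	if err := server.Serve(ln); err != nil {
		panic(err)
	}
}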
- s.logger.Warn("failed to upgrade websocket", zap.Error(err)) - return - } - conn := pikowebsocket.New(wsConn) - defer conn.Close() - - endpointID := c.Param("endpointID") - - s.logger.Debug( - "upstream connected", - zap.String("endpoint-id", endpointID), - zap.String("client-ip", c.ClientIP()), - ) - defer s.logger.Debug( - "upstream disconnected", - zap.String("endpoint-id", endpointID), - zap.String("client-ip", c.ClientIP()), - ) - - sess := muxado.NewTypedStreamSession(muxado.Server(conn, &muxado.Config{})) - heartbeat := muxado.NewHeartbeat( - sess, - func(d time.Duration, timeout bool) {}, - muxado.NewHeartbeatConfig(), - ) - - upstream := NewMuxUpstream(endpointID, sess) - s.manager.Add(upstream) - defer s.manager.Remove(upstream) - - for { - // The server doesn't yet accept streams, though need to keep accepting - // to respond to heartbeats and detect close. - _, err := heartbeat.AcceptStream() - if err != nil { - s.logger.Warn("accept stream", zap.Error(err)) - return - } - } -} - -func (s *Server) panicRoute(c *gin.Context, err any) { - s.logger.Error( - "handler panic", - zap.String("path", c.FullPath()), - zap.Any("err", err), - ) - c.AbortWithStatus(http.StatusInternalServerError) -} - -func init() { - // Disable Gin debug logs. - gin.SetMode(gin.ReleaseMode) -} diff --git a/serverv2/upstream/upstream.go b/serverv2/upstream/upstream.go deleted file mode 100644 index e52d93c..0000000 --- a/serverv2/upstream/upstream.go +++ /dev/null @@ -1,52 +0,0 @@ -package upstream - -import ( - "net" - - "golang.ngrok.com/muxado/v2" -) - -type Upstream interface { - EndpointID() string - Dial() (net.Conn, error) -} - -type TCPUpstream struct { - endpointID string - addr string -} - -func NewTCPUpstream(endpointID, addr string) *TCPUpstream { - return &TCPUpstream{ - endpointID: endpointID, - addr: addr, - } -} - -func (u *TCPUpstream) EndpointID() string { - return u.endpointID -} - -func (u *TCPUpstream) Dial() (net.Conn, error) { - return net.Dial("tcp", u.addr) -} - -type MuxUpstream struct { - endpointID string - sess muxado.TypedStreamSession -} - -func NewMuxUpstream(endpointID string, sess muxado.TypedStreamSession) *MuxUpstream { - return &MuxUpstream{ - endpointID: endpointID, - sess: sess, - } -} - -func (u *MuxUpstream) EndpointID() string { - return u.endpointID -} - -func (u *MuxUpstream) Dial() (net.Conn, error) { - return u.sess.OpenTypedStream(0) -} diff --git a/status/client/client.go b/status/client/client.go deleted file mode 100644 index cfc5efe..0000000 --- a/status/client/client.go +++ /dev/null @@ -1,135 +0,0 @@ -package client - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - fspath "path" - "time" - - "github.com/andydunstall/piko/pkg/gossip" - "github.com/andydunstall/piko/server/cluster" -) - -type Client struct { - httpClient *http.Client - - url *url.URL - - forward string -} - -func NewClient(url *url.URL, forward string) *Client { - return &Client{ - httpClient: &http.Client{ - Timeout: time.Second * 15, - }, - url: url, - forward: forward, - } -} - -func (c *Client) ProxyEndpoints() (map[string][]string, error) { - r, err := c.request("/status/proxy/endpoints") - if err != nil { - return nil, err - } - defer r.Close() - - var endpoints map[string][]string - if err := json.NewDecoder(r).Decode(&endpoints); err != nil { - return nil, fmt.Errorf("decode response: %w", err) - } - return endpoints, nil -} - -func (c *Client) ClusterNodes() ([]*cluster.NodeMetadata, error) { - r, err := c.request("/status/cluster/nodes") - if err != 
nil { - return nil, err - } - defer r.Close() - - var nodes []*cluster.NodeMetadata - if err := json.NewDecoder(r).Decode(&nodes); err != nil { - return nil, fmt.Errorf("decode response: %w", err) - } - return nodes, nil -} - -func (c *Client) ClusterNode(nodeID string) (*cluster.Node, error) { - r, err := c.request("/status/cluster/nodes/" + nodeID) - if err != nil { - return nil, err - } - defer r.Close() - - var node cluster.Node - if err := json.NewDecoder(r).Decode(&node); err != nil { - return nil, fmt.Errorf("decode response: %w", err) - } - return &node, nil -} - -func (c *Client) GossipNodes() ([]gossip.NodeMetadata, error) { - r, err := c.request("/status/gossip/nodes") - if err != nil { - return nil, err - } - defer r.Close() - - var members []gossip.NodeMetadata - if err := json.NewDecoder(r).Decode(&members); err != nil { - return nil, fmt.Errorf("decode response: %w", err) - } - return members, nil -} - -func (c *Client) GossipNode(memberID string) (*gossip.NodeState, error) { - r, err := c.request("/status/gossip/nodes/" + memberID) - if err != nil { - return nil, err - } - defer r.Close() - - var member gossip.NodeState - if err := json.NewDecoder(r).Decode(&member); err != nil { - return nil, fmt.Errorf("decode response: %w", err) - } - return &member, nil -} - -func (c *Client) Close() { - c.httpClient.CloseIdleConnections() -} - -func (c *Client) request(path string) (io.ReadCloser, error) { - url := new(url.URL) - *url = *c.url - - if c.forward != "" { - url.RawQuery = "forward=" + c.forward - } - - url.Path = fspath.Join(url.Path, path) - - req, err := http.NewRequest(http.MethodGet, url.String(), nil) - if err != nil { - return nil, fmt.Errorf("request: %w", err) - } - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("request: %w", err) - } - - if resp.StatusCode != http.StatusOK { - defer resp.Body.Close() - - return nil, fmt.Errorf("request: bad status: %d", resp.StatusCode) - } - - return resp.Body, nil -} diff --git a/tests/cluster_test.go b/tests/cluster_test.go deleted file mode 100644 index c062356..0000000 --- a/tests/cluster_test.go +++ /dev/null @@ -1,71 +0,0 @@ -//go:build system - -package tests - -import ( - "context" - "net/url" - "sync" - "testing" - "time" - - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/server" - statusclient "github.com/andydunstall/piko/status/client" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestCluster(t *testing.T) { - t.Run("discover", func(t *testing.T) { - var wg sync.WaitGroup - - server1Conf := defaultServerConfig() - server1, err := server.NewServer(server1Conf, log.NewNopLogger()) - require.NoError(t, err) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - wg.Add(1) - go func() { - defer wg.Done() - require.NoError(t, server1.Run(ctx)) - }() - - server2Conf := defaultServerConfig() - server2Conf.Cluster.Join = []string{server1Conf.Gossip.AdvertiseAddr} - server2, err := server.NewServer(server2Conf, log.NewNopLogger()) - require.NoError(t, err) - - wg.Add(1) - go func() { - defer wg.Done() - require.NoError(t, server2.Run(ctx)) - }() - - // Wait for each server to discover the other. 
- for _, addr := range []string{ - server1Conf.Admin.AdvertiseAddr, - server2Conf.Admin.AdvertiseAddr, - } { - for { - statusClient := statusclient.NewClient(&url.URL{ - Scheme: "http", - Host: addr, - }, "") - nodes, err := statusClient.ClusterNodes() - assert.NoError(t, err) - - if len(nodes) < 2 { - <-time.After(time.Millisecond * 10) - continue - } - break - } - } - - cancel() - wg.Wait() - }) -} diff --git a/tests/proxy_test.go b/tests/proxy_test.go deleted file mode 100644 index 66ffbf5..0000000 --- a/tests/proxy_test.go +++ /dev/null @@ -1,174 +0,0 @@ -//go:build system - -package tests - -import ( - "context" - "crypto/rand" - "net/http" - "net/http/httptest" - "net/url" - "testing" - "time" - - "github.com/andydunstall/piko/agent" - agentconfig "github.com/andydunstall/piko/agent/config" - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/server" - statusclient "github.com/andydunstall/piko/status/client" - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestProxy(t *testing.T) { - serverConf := defaultServerConfig() - server, err := server.NewServer(serverConf, log.NewNopLogger()) - require.NoError(t, err) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - go func() { - require.NoError(t, server.Run(ctx)) - }() - - upstream := httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) {}, - )) - defer upstream.Close() - - agentConf := defaultAgentConfig(serverConf.Upstream.AdvertiseAddr) - endpoint := agent.NewEndpoint( - "my-endpoint", - upstream.Listener.Addr().String(), - agentConf, - nil, - agent.NewMetrics(), - log.NewNopLogger(), - ) - go func() { - assert.NoError(t, endpoint.Run(ctx)) - }() - - // Wait for the agent to register the endpoint with Piko. - for { - statusClient := statusclient.NewClient(&url.URL{ - Scheme: "http", - Host: serverConf.Admin.AdvertiseAddr, - }, "") - endpoints, err := statusClient.ProxyEndpoints() - assert.NoError(t, err) - - if len(endpoints) == 0 { - <-time.After(time.Millisecond * 10) - continue - } - - _, ok := endpoints["my-endpoint"] - assert.True(t, ok) - break - } - - // Send a request to Piko which should be forwarded to the upstream server. 
- client := &http.Client{} - req, _ := http.NewRequest("GET", "http://"+serverConf.Proxy.AdvertiseAddr, nil) - req.Header.Set("x-piko-endpoint", "my-endpoint") - resp, err := client.Do(req) - assert.NoError(t, err) - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) -} - -func TestProxy_Authenticated(t *testing.T) { - hsSecretKey := generateTestHSKey(t) - - serverConf := defaultServerConfig() - serverConf.Auth.TokenHMACSecretKey = string(hsSecretKey) - - server, err := server.NewServer(serverConf, log.NewNopLogger()) - require.NoError(t, err) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - go func() { - require.NoError(t, server.Run(ctx)) - }() - - upstream := httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) {}, - )) - defer upstream.Close() - - endpointClaims := jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - Issuer: "bar", - } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, endpointClaims) - apiKey, err := token.SignedString([]byte(hsSecretKey)) - assert.NoError(t, err) - - agentConf := defaultAgentConfig(serverConf.Upstream.AdvertiseAddr) - agentConf.Auth.APIKey = apiKey - endpoint := agent.NewEndpoint( - "my-endpoint", - upstream.Listener.Addr().String(), - agentConf, - nil, - agent.NewMetrics(), - log.NewNopLogger(), - ) - go func() { - assert.NoError(t, endpoint.Run(ctx)) - }() - - // Wait for the agent to register the endpoint with Piko. - for { - statusClient := statusclient.NewClient(&url.URL{ - Scheme: "http", - Host: serverConf.Admin.AdvertiseAddr, - }, "") - endpoints, err := statusClient.ProxyEndpoints() - assert.NoError(t, err) - - if len(endpoints) == 0 { - <-time.After(time.Millisecond * 10) - continue - } - - _, ok := endpoints["my-endpoint"] - assert.True(t, ok) - break - } - - // Send a request to Piko which should be forwarded to the upstream server. 
- client := &http.Client{} - req, _ := http.NewRequest("GET", "http://"+serverConf.Proxy.AdvertiseAddr, nil) - req.Header.Set("x-piko-endpoint", "my-endpoint") - resp, err := client.Do(req) - assert.NoError(t, err) - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) -} - -func defaultAgentConfig(serverAddr string) *agentconfig.Config { - return &agentconfig.Config{ - Server: agentconfig.ServerConfig{ - URL: "http://" + serverAddr, - HeartbeatInterval: time.Second, - HeartbeatTimeout: time.Second, - }, - Forwarder: agentconfig.ForwarderConfig{ - Timeout: time.Second, - }, - Admin: agentconfig.AdminConfig{ - BindAddr: "127.0.0.1:0", - }, - } -} - -func generateTestHSKey(t *testing.T) []byte { - b := make([]byte, 10) - _, err := rand.Read(b) - require.NoError(t, err) - return b -} diff --git a/tests/server_test.go b/tests/server_test.go deleted file mode 100644 index b670233..0000000 --- a/tests/server_test.go +++ /dev/null @@ -1,68 +0,0 @@ -//go:build system - -package tests - -import ( - "context" - "net/http" - "testing" - "time" - - "github.com/andydunstall/piko/pkg/gossip" - "github.com/andydunstall/piko/pkg/log" - "github.com/andydunstall/piko/server" - serverconfig "github.com/andydunstall/piko/server/config" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestServer_AdminAPI(t *testing.T) { - serverConf := defaultServerConfig() - server, err := server.NewServer(serverConf, log.NewNopLogger()) - require.NoError(t, err) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - go func() { - require.NoError(t, server.Run(ctx)) - }() - - t.Run("health", func(t *testing.T) { - resp, err := http.Get( - "http://" + serverConf.Admin.AdvertiseAddr + "/health", - ) - assert.NoError(t, err) - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) - }) - - t.Run("metrics", func(t *testing.T) { - resp, err := http.Get( - "http://" + serverConf.Admin.AdvertiseAddr + "/metrics", - ) - assert.NoError(t, err) - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) - }) -} - -// defaultServerConfig returns the default server configuration for local -// tests. 
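The authenticated proxy test above builds its API key by signing an HS256 token with the same HMAC secret the server is configured with (serverConf.Auth.TokenHMACSecretKey), and the agent then presents that key via agentConf.Auth.APIKey. A standalone sketch of generating such a key with golang-jwt (the secret value and claims are placeholders):

package main

import (
	"fmt"
	"time"

	"github.com/golang-jwt/jwt/v5"
)

func main() {
	// Must match the HMAC secret the Piko server is configured with.
	secret := []byte("change-me")

	claims := jwt.RegisteredClaims{
		ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)

	apiKey, err := token.SignedString(secret)
	if err != nil {
		panic(err)
	}

	// Hand apiKey to the agent as its API key so it can authenticate when
	// registering endpoints.
	fmt.Println(apiKey)
}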
-func defaultServerConfig() *serverconfig.Config { - return &serverconfig.Config{ - Proxy: serverconfig.ProxyConfig{ - BindAddr: "127.0.0.1:0", - GatewayTimeout: time.Second, - }, - Upstream: serverconfig.UpstreamConfig{ - BindAddr: "127.0.0.1:0", - }, - Admin: serverconfig.AdminConfig{ - BindAddr: "127.0.0.1:0", - }, - Gossip: gossip.Config{ - Interval: time.Millisecond * 10, - MaxPacketSize: 1400, - }, - } -} diff --git a/workload/upstream/upstream.go b/workload/upstream/upstream.go index 68f2ff5..1c27d92 100644 --- a/workload/upstream/upstream.go +++ b/workload/upstream/upstream.go @@ -2,57 +2,17 @@ package upstream import ( "context" - "errors" "fmt" "io" - "net" "net/http" - "sync" - "time" + "net/http/httptest" - "github.com/andydunstall/piko/agent" - agentconfig "github.com/andydunstall/piko/agent/config" + "github.com/andydunstall/piko/agent/client" + "github.com/andydunstall/piko/agent/config" + "github.com/andydunstall/piko/agent/reverseproxy" "github.com/andydunstall/piko/pkg/log" - "go.uber.org/zap" ) -type server struct { - ln net.Listener - server *http.Server -} - -func newServer() (*server, error) { - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return nil, fmt.Errorf("listen: %w", err) - } - - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - //nolint - io.Copy(w, r.Body) - }) - return &server{ - server: &http.Server{ - Addr: ln.Addr().String(), - Handler: mux, - }, - ln: ln, - }, nil -} - -func (s *server) Addr() string { - return s.ln.Addr().String() -} - -func (s *server) Serve() error { - return s.server.Serve(s.ln) -} - -func (s *server) Close() error { - return s.server.Close() -} - type Upstream struct { endpointID string serverURL string @@ -68,44 +28,29 @@ func NewUpstream(endpointID string, serverURL string, logger log.Logger) *Upstre } func (u *Upstream) Run(ctx context.Context) error { - server, err := newServer() - if err != nil { - return err - } + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + //nolint + io.Copy(w, r.Body) + })) defer server.Close() - var wg sync.WaitGroup + client := client.New(client.WithURL(u.serverURL)) - wg.Add(1) + ln, err := client.Listen(context.Background(), u.endpointID) + if err != nil { + return fmt.Errorf("listen: %s: %w", ln.EndpointID(), err) + } + defer ln.Close() + + proxy := reverseproxy.NewServer(config.ListenerConfig{ + EndpointID: u.endpointID, + Addr: server.Listener.Addr().String(), + }, nil, log.NewNopLogger()) go func() { - defer wg.Done() - if err := server.Serve(); err != nil && !errors.Is(err, http.ErrServerClosed) { - u.logger.Error("failed to serve upstream", zap.Error(err)) - } + _ = proxy.Serve(ln) }() - agentConf := agentConfig(u.serverURL) - endpoint := agent.NewEndpoint( - u.endpointID, server.Addr(), agentConf, nil, agent.NewMetrics(), log.NewNopLogger(), - ) - if err = endpoint.Run(ctx); err != nil { - return fmt.Errorf("endpoint: %w", err) - } + <-ctx.Done() + proxy.Shutdown(context.Background()) return nil } - -func agentConfig(serverURL string) *agentconfig.Config { - return &agentconfig.Config{ - Server: agentconfig.ServerConfig{ - URL: serverURL, - HeartbeatInterval: time.Second, - HeartbeatTimeout: time.Second, - }, - Forwarder: agentconfig.ForwarderConfig{ - Timeout: time.Second, - }, - Admin: agentconfig.AdminConfig{ - BindAddr: "127.0.0.1:0", - }, - } -}
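With the workload upstream now built on the agent/client package, the same pattern works standalone: register a listener for an endpoint, then serve over it. A minimal sketch (the server URL and endpoint ID are placeholders, and passing the listener to http.Serve assumes it satisfies net.Listener, mirroring how proxy.Serve consumes it above):

package main

import (
	"context"
	"fmt"
	"net/http"

	"github.com/andydunstall/piko/agent/client"
)

func main() {
	// Placeholder server URL.
	c := client.New(client.WithURL("ws://localhost:8001"))

	// Register a listener for the endpoint with the Piko server.
	ln, err := c.Listen(context.Background(), "my-endpoint")
	if err != nil {
		panic(err)
	}
	defer ln.Close()

	// Serve proxied requests for the 'my-endpoint' endpoint.
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "hello from my-endpoint")
	})
	if err := http.Serve(ln, handler); err != nil {
		panic(err)
	}
}

The workload runner above instead hands the listener to the agent reverse proxy, so requests are forwarded to its local httptest server.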