diff --git a/api/api_test.go b/api/api_test.go index a4f753a4..d1dd6233 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -180,7 +180,7 @@ func TestGetProxies(t *testing.T) { Network: config.DefaultNetwork, Address: config.DefaultAddress, } - client := network.NewClient(context.TODO(), clientConfig, zerolog.Logger{}) + client := network.NewClient(context.TODO(), clientConfig, zerolog.Logger{}, nil) newPool := pool.NewPool(context.TODO(), 1) assert.Nil(t, newPool.Put(client.ID, client)) @@ -225,7 +225,7 @@ func TestGetServers(t *testing.T) { Network: config.DefaultNetwork, Address: config.DefaultAddress, } - client := network.NewClient(context.TODO(), clientConfig, zerolog.Logger{}) + client := network.NewClient(context.TODO(), clientConfig, zerolog.Logger{}, nil) newPool := pool.NewPool(context.TODO(), 1) assert.Nil(t, newPool.Put(client.ID, client)) diff --git a/cmd/run.go b/cmd/run.go index 4e234721..7a8c7c61 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -248,7 +248,7 @@ var runCmd = &cobra.Command{ ) // Load plugins and register their hooks. - pluginRegistry.LoadPlugins(runCtx, conf.Plugin.Plugins) + pluginRegistry.LoadPlugins(runCtx, conf.Plugin.Plugins, conf.Plugin.StartTimeout) // Start the metrics merger if enabled. var metricsMerger *metrics.Merger @@ -295,7 +295,7 @@ var runCmd = &cobra.Command{ logger.Info().Str("name", pluginId.Name).Msg("Reloading crashed plugin") pluginConfig := conf.Plugin.GetPlugins(pluginId.Name) if pluginConfig != nil { - pluginRegistry.LoadPlugins(runCtx, pluginConfig) + pluginRegistry.LoadPlugins(runCtx, pluginConfig, conf.Plugin.StartTimeout) } } else { logger.Trace().Str("name", pluginId.Name).Msg("Successfully pinged plugin") @@ -415,7 +415,15 @@ var runCmd = &cobra.Command{ // Check if the metrics server is already running before registering the handler. 
if _, err = http.Get(address); err != nil { //nolint:gosec - mux.Handle(metricsConfig.Path, gziphandler.GzipHandler(handler)) + // The timeout handler limits the nested handlers from running for too long. + mux.Handle( + metricsConfig.Path, + http.TimeoutHandler( + gziphandler.GzipHandler(handler), + metricsConfig.GetTimeout(), + "The request timed out while fetching the metrics", + ), + ) } else { logger.Warn().Msg("Metrics server is already running, consider changing the port") span.RecordError(err) @@ -426,9 +434,16 @@ var runCmd = &cobra.Command{ Addr: metricsConfig.Address, Handler: mux, ReadHeaderTimeout: metricsConfig.GetReadHeaderTimeout(), + ReadTimeout: metricsConfig.GetTimeout(), + WriteTimeout: metricsConfig.GetTimeout(), + IdleTimeout: metricsConfig.GetTimeout(), } - logger.Info().Str("address", address).Msg("Metrics are exposed") + logger.Info().Fields(map[string]interface{}{ + "address": address, + "timeout": metricsConfig.GetTimeout().String(), + "readHeaderTimeout": metricsConfig.GetReadHeaderTimeout().String(), + }).Msg("Metrics are exposed") if metricsConfig.CertFile != "" && metricsConfig.KeyFile != "" { // Set up TLS. @@ -507,11 +522,21 @@ var runCmd = &cobra.Command{ clients[name].ReceiveTimeout = clients[name].GetReceiveTimeout() clients[name].SendDeadline = clients[name].GetSendDeadline() clients[name].ReceiveChunkSize = clients[name].GetReceiveChunkSize() + clients[name].DialTimeout = clients[name].GetDialTimeout() // Add clients to the pool. 
for i := 0; i < cfg.GetSize(); i++ { clientConfig := clients[name] - client := network.NewClient(runCtx, clientConfig, logger) + client := network.NewClient( + runCtx, clientConfig, logger, + network.NewRetry( + clientConfig.Retries, + clientConfig.GetBackoff(), + clientConfig.BackoffMultiplier, + clientConfig.DisableBackoffCaps, + loggers[name], + ), + ) if client != nil { eventOptions := trace.WithAttributes( @@ -522,10 +547,15 @@ var runCmd = &cobra.Command{ attribute.String("receiveDeadline", client.ReceiveDeadline.String()), attribute.String("receiveTimeout", client.ReceiveTimeout.String()), attribute.String("sendDeadline", client.SendDeadline.String()), + attribute.String("dialTimeout", client.DialTimeout.String()), attribute.Bool("tcpKeepAlive", client.TCPKeepAlive), attribute.String("tcpKeepAlivePeriod", client.TCPKeepAlivePeriod.String()), attribute.String("localAddress", client.LocalAddr()), attribute.String("remoteAddress", client.RemoteAddr()), + attribute.Int("retries", clientConfig.Retries), + attribute.String("backoff", clientConfig.GetBackoff().String()), + attribute.Float64("backoffMultiplier", clientConfig.BackoffMultiplier), + attribute.Bool("disableBackoffCaps", clientConfig.DisableBackoffCaps), ) if client.ID != "" { eventOptions = trace.WithAttributes( @@ -547,8 +577,15 @@ var runCmd = &cobra.Command{ "receiveDeadline": client.ReceiveDeadline.String(), "receiveTimeout": client.ReceiveTimeout.String(), "sendDeadline": client.SendDeadline.String(), + "dialTimeout": client.DialTimeout.String(), "tcpKeepAlive": client.TCPKeepAlive, "tcpKeepAlivePeriod": client.TCPKeepAlivePeriod.String(), + "localAddress": client.LocalAddr(), + "remoteAddress": client.RemoteAddr(), + "retries": clientConfig.Retries, + "backoff": clientConfig.GetBackoff().String(), + "backoffMultiplier": clientConfig.BackoffMultiplier, + "disableBackoffCaps": clientConfig.DisableBackoffCaps, } _, err := pluginRegistry.Run( pluginTimeoutCtx, clientCfg, 
v1.HookName_HOOK_NAME_ON_NEW_CLIENT) diff --git a/config/config.go b/config/config.go index c21d0bec..e98b9109 100644 --- a/config/config.go +++ b/config/config.go @@ -116,6 +116,11 @@ func (c *Config) LoadDefaults(ctx context.Context) { ReceiveDeadline: DefaultReceiveDeadline, ReceiveTimeout: DefaultReceiveTimeout, SendDeadline: DefaultSendDeadline, + DialTimeout: DefaultDialTimeout, + Retries: DefaultRetries, + Backoff: DefaultBackoff, + BackoffMultiplier: DefaultBackoffMultiplier, + DisableBackoffCaps: DefaultDisableBackoffCaps, } defaultPool := Pool{ @@ -210,6 +215,7 @@ func (c *Config) LoadDefaults(ctx context.Context) { HealthCheckPeriod: DefaultPluginHealthCheckPeriod, ReloadOnCrash: true, Timeout: DefaultPluginTimeout, + StartTimeout: DefaultPluginStartTimeout, } if c.GlobalKoanf != nil { diff --git a/config/constants.go b/config/constants.go index 3693dacb..224885b9 100644 --- a/config/constants.go +++ b/config/constants.go @@ -92,6 +92,7 @@ const ( DefaultMetricsMergerPeriod = 5 * time.Second DefaultPluginHealthCheckPeriod = 5 * time.Second DefaultPluginTimeout = 30 * time.Second + DefaultPluginStartTimeout = 1 * time.Minute // Client constants. DefaultNetwork = "tcp" @@ -102,6 +103,11 @@ const ( DefaultTCPKeepAlivePeriod = 30 * time.Second DefaultTCPKeepAlive = false DefaultReceiveTimeout = 0 + DefaultDialTimeout = 60 * time.Second + DefaultRetries = 3 + DefaultBackoff = 1 * time.Second + DefaultBackoffMultiplier = 2.0 + DefaultDisableBackoffCaps = false // Pool constants. EmptyPoolCapacity = 0 diff --git a/config/getters.go b/config/getters.go index 2a9f087b..29dabd65 100644 --- a/config/getters.go +++ b/config/getters.go @@ -106,7 +106,7 @@ func (p PluginConfig) GetTerminationPolicy() TerminationPolicy { // GetTCPKeepAlivePeriod returns the TCP keep alive period from config file or default value. 
func (c Client) GetTCPKeepAlivePeriod() time.Duration { - if c.TCPKeepAlivePeriod <= 0 { + if c.TCPKeepAlivePeriod < 0 { return DefaultTCPKeepAlivePeriod } return c.TCPKeepAlivePeriod @@ -114,7 +114,7 @@ func (c Client) GetTCPKeepAlivePeriod() time.Duration { // GetReceiveDeadline returns the receive deadline from config file or default value. func (c Client) GetReceiveDeadline() time.Duration { - if c.ReceiveDeadline <= 0 { + if c.ReceiveDeadline < 0 { return DefaultReceiveDeadline } return c.ReceiveDeadline @@ -122,7 +122,7 @@ func (c Client) GetReceiveDeadline() time.Duration { // GetReceiveTimeout returns the receive timeout from config file or default value. func (c Client) GetReceiveTimeout() time.Duration { - if c.ReceiveTimeout <= 0 { + if c.ReceiveTimeout < 0 { return DefaultReceiveTimeout } return c.ReceiveTimeout @@ -130,7 +130,7 @@ func (c Client) GetReceiveTimeout() time.Duration { // GetSendDeadline returns the send deadline from config file or default value. func (c Client) GetSendDeadline() time.Duration { - if c.SendDeadline <= 0 { + if c.SendDeadline < 0 { return DefaultSendDeadline } return c.SendDeadline @@ -144,6 +144,22 @@ func (c Client) GetReceiveChunkSize() int { return c.ReceiveChunkSize } +// GetDialTimeout returns the dial timeout from config file or default value. +func (c Client) GetDialTimeout() time.Duration { + if c.DialTimeout < 0 { + return DefaultDialTimeout + } + return c.DialTimeout +} + +// GetBackoff returns the backoff from config file or default value. +func (c Client) GetBackoff() time.Duration { + if c.Backoff < 0 { + return DefaultBackoff + } + return c.Backoff +} + // GetHealthCheckPeriod returns the health check period from config file or default value. func (pr Proxy) GetHealthCheckPeriod() time.Duration { if pr.HealthCheckPeriod <= 0 { @@ -154,7 +170,7 @@ func (pr Proxy) GetHealthCheckPeriod() time.Duration { // GetTickInterval returns the tick interval from config file or default value. 
func (s Server) GetTickInterval() time.Duration { - if s.TickInterval <= 0 { + if s.TickInterval < 0 { return DefaultTickInterval } return s.TickInterval @@ -247,20 +263,23 @@ func GetDefaultConfigFilePath(filename string) string { return filepath.Join("./", filename) } +// GetReadHeaderTimeout returns the read header timeout from config file or default value. func (m Metrics) GetReadHeaderTimeout() time.Duration { - if m.ReadHeaderTimeout <= 0 { + if m.ReadHeaderTimeout < 0 { return DefaultReadHeaderTimeout } return m.ReadHeaderTimeout } +// GetTimeout returns the metrics server timeout from config file or default value. func (m Metrics) GetTimeout() time.Duration { - if m.Timeout <= 0 { + if m.Timeout < 0 { return DefaultMetricsServerTimeout } return m.Timeout } +// Filter returns a filtered global config based on the group name. func (gc GlobalConfig) Filter(groupName string) *GlobalConfig { if _, ok := gc.Servers[groupName]; !ok { return nil diff --git a/config/getters_test.go b/config/getters_test.go index 94a1af46..8812f55f 100644 --- a/config/getters_test.go +++ b/config/getters_test.go @@ -36,7 +36,7 @@ func TestGetTerminationPolicy(t *testing.T) { // TestGetTCPKeepAlivePeriod tests the GetTCPKeepAlivePeriod function. func TestGetTCPKeepAlivePeriod(t *testing.T) { client := Client{} - assert.Equal(t, DefaultTCPKeepAlivePeriod, client.GetTCPKeepAlivePeriod()) + assert.Equal(t, client.GetTCPKeepAlivePeriod(), time.Duration(0)) } // TestGetReceiveDeadline tests the GetReceiveDeadline function. @@ -72,7 +72,7 @@ func TestGetHealthCheckPeriod(t *testing.T) { // TestGetTickInterval tests the GetTickInterval function. func TestGetTickInterval(t *testing.T) { server := Server{} - assert.Equal(t, DefaultTickInterval, server.GetTickInterval()) + assert.Equal(t, server.GetTickInterval(), time.Duration(0)) } // TestGetSize tests the GetSize function. 
@@ -120,13 +120,13 @@ func TestGetDefaultConfigFilePath(t *testing.T) { // TestGetReadTimeout tests the GetReadTimeout function. func TestGetReadHeaderTimeout(t *testing.T) { metrics := Metrics{} - assert.Equal(t, DefaultReadHeaderTimeout, metrics.GetReadHeaderTimeout()) + assert.Equal(t, metrics.GetReadHeaderTimeout(), time.Duration(0)) } // TestGetTimeout tests the GetTimeout function of the metrics server. func TestGetTimeout(t *testing.T) { metrics := Metrics{} - assert.Equal(t, DefaultMetricsServerTimeout, metrics.GetTimeout()) + assert.Equal(t, metrics.GetTimeout(), time.Duration(0)) } // TestFilter tests the Filter function. diff --git a/config/types.go b/config/types.go index 23a9e6bd..5e7e5780 100644 --- a/config/types.go +++ b/config/types.go @@ -24,6 +24,7 @@ type PluginConfig struct { HealthCheckPeriod time.Duration `json:"healthCheckPeriod" jsonschema:"oneof_type=string;integer"` ReloadOnCrash bool `json:"reloadOnCrash"` Timeout time.Duration `json:"timeout" jsonschema:"oneof_type=string;integer"` + StartTimeout time.Duration `json:"startTimeout" jsonschema:"oneof_type=string;integer"` Plugins []Plugin `json:"plugins"` } @@ -36,6 +37,11 @@ type Client struct { ReceiveDeadline time.Duration `json:"receiveDeadline" jsonschema:"oneof_type=string;integer"` ReceiveTimeout time.Duration `json:"receiveTimeout" jsonschema:"oneof_type=string;integer"` SendDeadline time.Duration `json:"sendDeadline" jsonschema:"oneof_type=string;integer"` + DialTimeout time.Duration `json:"dialTimeout" jsonschema:"oneof_type=string;integer"` + Retries int `json:"retries"` + Backoff time.Duration `json:"backoff" jsonschema:"oneof_type=string;integer"` + BackoffMultiplier float64 `json:"backoffMultiplier"` + DisableBackoffCaps bool `json:"disableBackoffCaps"` } type Logger struct { diff --git a/gatewayd.yaml b/gatewayd.yaml index 93a328c2..40fd6ea6 100644 --- a/gatewayd.yaml +++ b/gatewayd.yaml @@ -7,9 +7,10 @@ loggers: noColor: False timeFormat: "unix" # unixms, unixmicro and 
unixnano consoleTimeFormat: "RFC3339" # Go time format string - # If output is file, the following fields are used. + # If the output contains "file", the following fields are used: fileName: "gatewayd.log" maxSize: 500 # MB + # If maxBackups and maxAge are both 0, no old log files will be deleted. maxBackups: 5 maxAge: 30 # days compress: True @@ -39,6 +40,12 @@ clients: receiveDeadline: 0s # duration, 0ms/0s means no deadline receiveTimeout: 0s # duration, 0ms/0s means no timeout sendDeadline: 0s # duration, 0ms/0s means no deadline + dialTimeout: 60s # duration + # Retry configuration + retries: 3 # 0 means no retry + backoff: 1s # duration + backoffMultiplier: 2.0 # 0 means no backoff + disableBackoffCaps: false pools: default: diff --git a/gatewayd_plugins.yaml b/gatewayd_plugins.yaml index 655b7335..33bb7374 100644 --- a/gatewayd_plugins.yaml +++ b/gatewayd_plugins.yaml @@ -59,6 +59,9 @@ reloadOnCrash: True # The timeout controls how long to wait for a plugin to respond to a request before timing out. timeout: 30s +# The start timeout controls how long to wait for a plugin to start before timing out. +startTimeout: 1m + # The plugin configuration is a list of plugins to load. Each plugin is defined by a name, # a path to the plugin's executable, and a list of arguments to pass to the plugin. 
The # plugin's executable is expected to be a Go plugin that implements the GatewayD plugin diff --git a/metrics/merger.go b/metrics/merger.go index 7663a6a1..1e87634e 100644 --- a/metrics/merger.go +++ b/metrics/merger.go @@ -259,7 +259,7 @@ func (m *Merger) Start() { m.scheduler.StartAsync() m.Logger.Info().Fields( map[string]interface{}{ - "startDelay": startDelay, + "startDelay": startDelay.Format(time.RFC3339), "metricsMergerPeriod": m.MetricsMergerPeriod.String(), }, ).Msg("Started the metrics merger scheduler") diff --git a/network/client.go b/network/client.go index 97bd0382..3f07b10b 100644 --- a/network/client.go +++ b/network/client.go @@ -32,6 +32,7 @@ type Client struct { ctx context.Context //nolint:containedctx connected atomic.Bool mu sync.Mutex + retry IRetry TCPKeepAlive bool TCPKeepAlivePeriod time.Duration @@ -39,6 +40,7 @@ type Client struct { ReceiveDeadline time.Duration SendDeadline time.Duration ReceiveTimeout time.Duration + DialTimeout time.Duration ID string Network string // tcp/udp/unix Address string @@ -47,7 +49,9 @@ type Client struct { var _ IClient = (*Client)(nil) // NewClient creates a new client. -func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog.Logger) *Client { +func NewClient( + ctx context.Context, clientConfig *config.Client, logger zerolog.Logger, retry *Retry, +) *Client { clientCtx, span := otel.Tracer(config.TracerName).Start(ctx, "NewClient") defer span.End() @@ -69,10 +73,12 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog. // Create a resolved client. client = Client{ - ctx: clientCtx, - mu: sync.Mutex{}, - Network: clientConfig.Network, - Address: addr, + ctx: clientCtx, + mu: sync.Mutex{}, + retry: retry, + Network: clientConfig.Network, + Address: addr, + DialTimeout: clientConfig.DialTimeout, } // Fall back to the original network and address if the address can't be resolved. 
@@ -83,8 +89,24 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog. } } - // Create a new connection. - conn, origErr := net.Dial(client.Network, client.Address) + var origErr error + // Create a new connection and retry a few times if needed. + //nolint:wrapcheck + if conn, err := client.retry.Retry(func() (any, error) { + if client.DialTimeout > 0 { + return net.DialTimeout(client.Network, client.Address, client.DialTimeout) + } else { + return net.Dial(client.Network, client.Address) + } + }); err != nil { + origErr = err + } else { + if netConn, ok := conn.(net.Conn); ok { + client.conn = netConn + } else { + origErr = fmt.Errorf("unexpected connection type: %T", conn) + } + } if origErr != nil { err := gerr.ErrClientConnectionFailed.Wrap(origErr) logger.Error().Err(err).Msg("Failed to create a new connection") @@ -92,7 +114,6 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog. return nil } - client.conn = conn client.connected.Store(true) // Set the TCP keep alive. @@ -141,7 +162,11 @@ func NewClient(ctx context.Context, clientConfig *config.Client, logger zerolog. 
logger.Trace().Str("address", client.Address).Msg("New client created") client.ID = GetID( - conn.LocalAddr().Network(), conn.LocalAddr().String(), config.DefaultSeed, logger) + client.conn.LocalAddr().Network(), + client.conn.LocalAddr().String(), + config.DefaultSeed, + logger, + ) metrics.ServerConnections.Inc() @@ -182,6 +207,8 @@ func (c *Client) Send(data []byte) (int, *gerr.GatewayDError) { }, ).Msg("Sent data to server") + span.AddEvent("Sent data to server") + return sent, nil } @@ -222,6 +249,9 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) { break } } + + span.AddEvent("Received data from server") + return received, buffer.Bytes(), nil } @@ -245,19 +275,40 @@ func (c *Client) Reconnect() error { c.Address = address c.Network = network - conn, err := net.Dial(c.Network, c.Address) - if err != nil { - c.logger.Error().Err(err).Msg("Failed to reconnect") - span.RecordError(err) - return gerr.ErrClientConnectionFailed.Wrap(err) + var origErr error + // Create a new connection and retry a few times if needed. 
+ //nolint:wrapcheck + if conn, err := c.retry.Retry(func() (any, error) { + if c.DialTimeout > 0 { + return net.DialTimeout(c.Network, c.Address, c.DialTimeout) + } else { + return net.Dial(c.Network, c.Address) + } + }); err != nil { + origErr = err + } else { + if netConn, ok := conn.(net.Conn); ok { + c.conn = netConn + } else { + origErr = fmt.Errorf("unexpected connection type: %T", conn) + } + } + if origErr != nil { + c.logger.Error().Err(origErr).Msg("Failed to reconnect") + span.RecordError(origErr) + return gerr.ErrClientConnectionFailed.Wrap(origErr) } - c.conn = conn c.ID = GetID( - conn.LocalAddr().Network(), conn.LocalAddr().String(), config.DefaultSeed, c.logger) + c.conn.LocalAddr().Network(), + c.conn.LocalAddr().String(), + config.DefaultSeed, + c.logger, + ) c.connected.Store(true) c.logger.Debug().Str("address", c.Address).Msg("Reconnected to server") metrics.ServerConnections.Inc() + span.AddEvent("Reconnected to server") return nil } @@ -294,6 +345,8 @@ func (c *Client) Close() { c.Network = "" metrics.ServerConnections.Dec() + + span.AddEvent("Closed connection to server") } // IsConnected checks if the client is still connected to the server. 
diff --git a/network/client_test.go b/network/client_test.go index 5cb73857..8b99d1f9 100644 --- a/network/client_test.go +++ b/network/client_test.go @@ -33,10 +33,12 @@ func CreateNewClient(t *testing.T) *Client { ReceiveDeadline: config.DefaultReceiveDeadline, ReceiveTimeout: config.DefaultReceiveTimeout, SendDeadline: config.DefaultSendDeadline, + DialTimeout: config.DefaultDialTimeout, TCPKeepAlive: false, TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod, }, - logger) + logger, + nil) return client } @@ -144,9 +146,10 @@ func BenchmarkNewClient(b *testing.B) { ReceiveChunkSize: config.DefaultChunkSize, ReceiveDeadline: config.DefaultReceiveDeadline, SendDeadline: config.DefaultSendDeadline, + DialTimeout: config.DefaultDialTimeout, TCPKeepAlive: false, TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod, - }, logger) + }, logger, nil) client.Close() } } @@ -168,10 +171,11 @@ func BenchmarkSend(b *testing.B) { ReceiveChunkSize: config.DefaultChunkSize, ReceiveDeadline: config.DefaultReceiveDeadline, SendDeadline: config.DefaultSendDeadline, + DialTimeout: config.DefaultDialTimeout, TCPKeepAlive: false, TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod, }, - logger) + logger, nil) defer client.Close() packet := CreatePgStartupPacket() @@ -198,10 +202,11 @@ func BenchmarkReceive(b *testing.B) { ReceiveDeadline: config.DefaultReceiveDeadline, ReceiveTimeout: 1 * time.Millisecond, SendDeadline: config.DefaultSendDeadline, + DialTimeout: config.DefaultDialTimeout, TCPKeepAlive: false, TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod, }, - logger) + logger, nil) defer client.Close() packet := CreatePgStartupPacket() @@ -228,10 +233,11 @@ func BenchmarkIsConnected(b *testing.B) { ReceiveChunkSize: config.DefaultChunkSize, ReceiveDeadline: config.DefaultReceiveDeadline, SendDeadline: config.DefaultSendDeadline, + DialTimeout: config.DefaultDialTimeout, TCPKeepAlive: false, TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod, }, - logger) + logger, nil) 
defer client.Close() for i := 0; i < b.N; i++ { diff --git a/network/proxy.go b/network/proxy.go index adb6f028..bec9ac92 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -89,7 +89,16 @@ func NewProxy( proxy.availableConnections.Remove(client.ID) client.Close() // Create a new client. - client = NewClient(proxyCtx, proxy.ClientConfig, proxy.logger) + client = NewClient( + proxyCtx, proxy.ClientConfig, proxy.logger, + NewRetry( + proxy.ClientConfig.Retries, + proxy.ClientConfig.GetBackoff(), + proxy.ClientConfig.BackoffMultiplier, + proxy.ClientConfig.DisableBackoffCaps, + proxy.logger, + ), + ) if client != nil && client.ID != "" { if err := proxy.availableConnections.Put(client.ID, client); err != nil { proxy.logger.Err(err).Msg("Failed to update the client connection") @@ -146,7 +155,16 @@ func (pr *Proxy) Connect(conn *ConnWrapper) *gerr.GatewayDError { // Pool is exhausted or is elastic. if pr.Elastic { // Create a new client. - client = NewClient(pr.ctx, pr.ClientConfig, pr.logger) + client = NewClient( + pr.ctx, pr.ClientConfig, pr.logger, + NewRetry( + pr.ClientConfig.Retries, + pr.ClientConfig.GetBackoff(), + pr.ClientConfig.BackoffMultiplier, + pr.ClientConfig.DisableBackoffCaps, + pr.logger, + ), + ) span.AddEvent("Created a new client connection") pr.logger.Debug().Str("id", client.ID[:7]).Msg("Reused the client connection") } else { @@ -721,6 +739,9 @@ func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, *gerr.GatewayD "remote": RemoteAddr(conn), }, ).Msg("Received data from client") + + span.AddEvent("Received data from client") + metrics.BytesReceivedFromClient.Observe(float64(length)) metrics.TotalTrafficBytes.Observe(float64(length)) @@ -752,6 +773,8 @@ func (pr *Proxy) sendTrafficToServer(client *Client, request []byte) (int, *gerr }, ).Msg("Sent data to database") + span.AddEvent("Sent data to database") + metrics.BytesSentToServer.Observe(float64(sent)) metrics.TotalTrafficBytes.Observe(float64(sent)) @@ -779,6 +802,8 @@ func 
(pr *Proxy) receiveTrafficFromServer(client *Client) (int, []byte, *gerr.Ga pr.logger.Debug().Fields(fields).Msg("Received data from database") + span.AddEvent("Received data from database") + metrics.BytesReceivedFromServer.Observe(float64(received)) metrics.TotalTrafficBytes.Observe(float64(received)) diff --git a/network/proxy_test.go b/network/proxy_test.go index 72b2ec3f..295e87cd 100644 --- a/network/proxy_test.go +++ b/network/proxy_test.go @@ -34,10 +34,12 @@ func TestNewProxy(t *testing.T) { ReceiveChunkSize: config.DefaultChunkSize, ReceiveDeadline: config.DefaultReceiveDeadline, SendDeadline: config.DefaultSendDeadline, + DialTimeout: config.DefaultDialTimeout, TCPKeepAlive: false, TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod, }, - logger) + logger, + nil) err := newPool.Put(client.ID, client) assert.Nil(t, err) @@ -111,6 +113,7 @@ func TestNewProxyElastic(t *testing.T) { ReceiveChunkSize: config.DefaultChunkSize, ReceiveDeadline: config.DefaultReceiveDeadline, SendDeadline: config.DefaultSendDeadline, + DialTimeout: config.DefaultDialTimeout, TCPKeepAlive: false, TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod, }, @@ -198,6 +201,7 @@ func BenchmarkNewProxyElastic(b *testing.B) { ReceiveChunkSize: config.DefaultChunkSize, ReceiveDeadline: config.DefaultReceiveDeadline, SendDeadline: config.DefaultSendDeadline, + DialTimeout: config.DefaultDialTimeout, TCPKeepAlive: false, TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod, }, @@ -229,7 +233,7 @@ func BenchmarkProxyConnectDisconnect(b *testing.B) { TCPKeepAlive: false, TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod, } - newPool.Put("client", NewClient(context.Background(), &clientConfig, logger)) //nolint:errcheck + newPool.Put("client", NewClient(context.Background(), &clientConfig, logger, nil)) //nolint:errcheck // Create a proxy with a fixed buffer newPool proxy := NewProxy( @@ -283,7 +287,7 @@ func BenchmarkProxyPassThrough(b *testing.B) { TCPKeepAlive: false, TCPKeepAlivePeriod: 
config.DefaultTCPKeepAlivePeriod, } - newPool.Put("client", NewClient(context.Background(), &clientConfig, logger)) //nolint:errcheck + newPool.Put("client", NewClient(context.Background(), &clientConfig, logger, nil)) //nolint:errcheck // Create a proxy with a fixed buffer newPool proxy := NewProxy( @@ -341,7 +345,7 @@ func BenchmarkProxyIsHealthyAndIsExhausted(b *testing.B) { TCPKeepAlive: false, TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod, } - client := NewClient(context.Background(), &clientConfig, logger) + client := NewClient(context.Background(), &clientConfig, logger, nil) newPool.Put("client", client) //nolint:errcheck // Create a proxy with a fixed buffer newPool @@ -398,7 +402,7 @@ func BenchmarkProxyAvailableAndBusyConnections(b *testing.B) { TCPKeepAlive: false, TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod, } - client := NewClient(context.Background(), &clientConfig, logger) + client := NewClient(context.Background(), &clientConfig, logger, nil) newPool.Put("client", client) //nolint:errcheck // Create a proxy with a fixed buffer newPool diff --git a/network/retry.go b/network/retry.go new file mode 100644 index 00000000..3f720df3 --- /dev/null +++ b/network/retry.go @@ -0,0 +1,127 @@ +package network + +import ( + "errors" + "math" + "time" + + "github.com/rs/zerolog" +) + +const ( + BackoffMultiplierCap = 10 + BackoffDurationCap = time.Minute +) + +type RetryCallback func() (any, error) + +type IRetry interface { + Retry(_ RetryCallback) (any, error) +} + +type Retry struct { + logger zerolog.Logger + Retries int + Backoff time.Duration + BackoffMultiplier float64 + DisableBackoffCaps bool +} + +var _ IRetry = (*Retry)(nil) + +// Retry runs the callback function and retries it if it fails. +// It'll wait for the duration of the backoff between retries. 
+func (r *Retry) Retry(callback RetryCallback) (any, error) {
+	var (
+		object any
+		err    error
+		retry  int
+	)
+
+	if callback == nil {
+		return nil, errors.New("callback is nil")
+	}
+
+	// A nil *Retry means retrying is disabled: run the callback exactly once.
+	if r == nil {
+		return callback()
+	}
+
+	// The first attempt counts as a retry. The callback always runs at least
+	// once, even for a negative r.Retries, so (nil, nil) is never returned.
+	for ; retry == 0 || retry <= r.Retries; retry++ {
+		// The delay slept after a failed attempt is the backoff times the
+		// multiplier raised to the attempt number, e.g. 1s, 2s, 4s, ..., 32s
+		// for a 1s backoff and a multiplier of 2; 64s and above are capped to
+		// BackoffDurationCap unless the caps are disabled. It is computed in
+		// floating point so fractional multipliers (e.g. 1.5) aren't truncated
+		// and clamped before conversion so it can't overflow time.Duration.
+		delay := float64(r.Backoff) * math.Pow(r.BackoffMultiplier, float64(retry))
+		if !r.DisableBackoffCaps && delay > float64(BackoffDurationCap) {
+			delay = float64(BackoffDurationCap)
+		}
+
+		backoffDuration := time.Duration(math.MaxInt64)
+		switch {
+		case math.IsNaN(delay) || delay <= 0:
+			backoffDuration = 0
+		case delay < math.MaxInt64:
+			backoffDuration = time.Duration(delay)
+		}
+
+		if retry > 0 {
+			r.logger.Debug().Fields(
+				map[string]interface{}{
+					"retry": retry,
+					"delay": backoffDuration.String(),
+				},
+			).Msg("Trying to run callback again")
+		} else {
+			r.logger.Trace().Msg("First attempt to run callback")
+		}
+
+		// Try and retry the callback.
+		object, err = callback()
+		if err == nil {
+			return object, nil
+		}
+
+		// Back off only when another attempt follows the failed one.
+		if retry < r.Retries {
+			time.Sleep(backoffDuration)
+		}
+	}
+
+	r.logger.Error().Err(err).Msgf("Failed to run callback after %d retries", retry)
+
+	return nil, err
+}
+
+// NewRetry creates a Retry from the given settings. A retries value of zero is
+// replaced by one, and the backoff multiplier is limited to
+// BackoffMultiplierCap unless the backoff caps are disabled.
+func NewRetry(
+	retries int,
+	backoff time.Duration,
+	backoffMultiplier float64,
+	disableBackoffCaps bool,
+	logger zerolog.Logger,
+) *Retry {
+	retry := Retry{
+		Retries:            retries,
+		Backoff:            backoff,
+		BackoffMultiplier:  backoffMultiplier,
+		DisableBackoffCaps: disableBackoffCaps,
+		logger:             logger,
+	}
+
+	if retry.Retries == 0 {
+		retry.Retries = 1
+	}
+
+	if !retry.DisableBackoffCaps && retry.BackoffMultiplier > BackoffMultiplierCap {
+		retry.BackoffMultiplier = BackoffMultiplierCap
+	}
+
+	return &retry
+}
diff --git a/network/retry_test.go b/network/retry_test.go
new file mode 100644
index 00000000..a3889042
--- /dev/null
+++ b/network/retry_test.go
@@ -0,0 +1,77 @@
+package network
+
+import (
+	"context"
+	"net"
+	"testing"
+	"time"
+
+	"github.com/gatewayd-io/gatewayd/config"
+	"github.com/gatewayd-io/gatewayd/logging"
+	"github.com/rs/zerolog"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestRetry(t *testing.T) {
+	logger := logging.NewLogger(context.Background(), logging.LoggerConfig{
+		Output:            []config.LogOutput{config.Console},
+		TimeFormat:        zerolog.TimeFormatUnix,
+		ConsoleTimeFormat: time.RFC3339,
+		Level:             zerolog.DebugLevel,
+		NoColor:           true,
+	})
+
+	t.Run("DialTimeout", func(t *testing.T) {
+		t.Run("nil", func(t *testing.T) {
+			// Nil retry should just dial the connection once.
+ var retry *Retry + _, err := retry.Retry(nil) + assert.Error(t, err) + assert.ErrorContains(t, err, "callback is nil") + }) + t.Run("retry without timeout", func(t *testing.T) { + retry := NewRetry(0, 0, 0, false, logger) + assert.Equal(t, 1, retry.Retries) + assert.Equal(t, time.Duration(0), retry.Backoff) + assert.Equal(t, float64(0), retry.BackoffMultiplier) + assert.False(t, retry.DisableBackoffCaps) + + conn, err := retry.Retry(func() (any, error) { + return net.Dial("tcp", "localhost:5432") //nolint: wrapcheck + }) + assert.NoError(t, err) + assert.NotNil(t, conn) + assert.IsType(t, &net.TCPConn{}, conn) + if tcpConn, ok := conn.(*net.TCPConn); ok { + tcpConn.Close() + } else { + t.Errorf("Unexpected connection type: %T", conn) + } + }) + t.Run("retry with timeout", func(t *testing.T) { + retry := NewRetry( + config.DefaultRetries, + config.DefaultBackoff, + config.DefaultBackoffMultiplier, + config.DefaultDisableBackoffCaps, + logger, + ) + assert.Equal(t, config.DefaultRetries, retry.Retries) + assert.Equal(t, config.DefaultBackoff, retry.Backoff) + assert.Equal(t, config.DefaultBackoffMultiplier, retry.BackoffMultiplier) + assert.False(t, retry.DisableBackoffCaps) + + conn, err := retry.Retry(func() (any, error) { + return net.DialTimeout("tcp", "localhost:5432", config.DefaultDialTimeout) //nolint: wrapcheck + }) + assert.NoError(t, err) + assert.NotNil(t, conn) + assert.IsType(t, &net.TCPConn{}, conn) + if tcpConn, ok := conn.(*net.TCPConn); ok { + tcpConn.Close() + } else { + t.Errorf("Unexpected connection type: %T", conn) + } + }) + }) +} diff --git a/network/server_test.go b/network/server_test.go index e2cb24fa..62e2c10a 100644 --- a/network/server_test.go +++ b/network/server_test.go @@ -65,13 +65,13 @@ func TestRunServer(t *testing.T) { // Create a connection newPool. 
newPool := pool.NewPool(context.Background(), 3) - client1 := NewClient(context.Background(), &clientConfig, logger) + client1 := NewClient(context.Background(), &clientConfig, logger, nil) err := newPool.Put(client1.ID, client1) assert.Nil(t, err) - client2 := NewClient(context.Background(), &clientConfig, logger) + client2 := NewClient(context.Background(), &clientConfig, logger, nil) err = newPool.Put(client2.ID, client2) assert.Nil(t, err) - client3 := NewClient(context.Background(), &clientConfig, logger) + client3 := NewClient(context.Background(), &clientConfig, logger, nil) err = newPool.Put(client3.ID, client3) assert.Nil(t, err) @@ -134,10 +134,12 @@ func TestRunServer(t *testing.T) { ReceiveChunkSize: config.DefaultChunkSize, ReceiveDeadline: config.DefaultReceiveDeadline, SendDeadline: config.DefaultSendDeadline, + DialTimeout: config.DefaultDialTimeout, TCPKeepAlive: false, TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod, }, - logger) + logger, + nil) assert.NotNil(t, client) sent, err := client.Send(CreatePgStartupPacket()) diff --git a/network/utils_test.go b/network/utils_test.go index 2de94cd3..9a13324e 100644 --- a/network/utils_test.go +++ b/network/utils_test.go @@ -165,7 +165,7 @@ func BenchmarkTrafficData(b *testing.B) { TCPKeepAlive: false, TCPKeepAlivePeriod: time.Second * 10, ReceiveChunkSize: 1024, - }, logger) + }, logger, nil) fields := []Field{ { Name: "test", diff --git a/plugin/plugin_registry.go b/plugin/plugin_registry.go index e42ffcde..ff4127e6 100644 --- a/plugin/plugin_registry.go +++ b/plugin/plugin_registry.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "encoding/hex" "sort" + "time" "github.com/Masterminds/semver/v3" sdkPlugin "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin" @@ -44,7 +45,7 @@ type IRegistry interface { ForEach(f func(sdkPlugin.Identifier, *Plugin)) Remove(pluginID sdkPlugin.Identifier) Shutdown() - LoadPlugins(ctx context.Context, plugins []config.Plugin) + LoadPlugins(ctx context.Context, plugins 
[]config.Plugin, startTimeout time.Duration) RegisterHooks(ctx context.Context, pluginID sdkPlugin.Identifier) // Hook management @@ -62,6 +63,7 @@ type Registry struct { Verification config.VerificationPolicy Acceptance config.AcceptancePolicy Termination config.TerminationPolicy + StartTimeout time.Duration } var _ IRegistry = (*Registry)(nil) @@ -383,7 +385,9 @@ func (reg *Registry) Run( } // LoadPlugins loads plugins from the config file. -func (reg *Registry) LoadPlugins(ctx context.Context, plugins []config.Plugin) { +func (reg *Registry) LoadPlugins( + ctx context.Context, plugins []config.Plugin, startTimeout time.Duration, +) { // TODO: Append built-in plugins to the list of plugins // Built-in plugins are plugins that are compiled and shipped with the gatewayd binary. ctx, span := otel.Tracer("").Start(ctx, "Load plugins") @@ -484,6 +488,7 @@ func (reg *Registry) LoadPlugins(ctx context.Context, plugins []config.Plugin) { MinPort: config.DefaultMinPort, MaxPort: config.DefaultMaxPort, AutoMTLS: true, + StartTimeout: startTimeout, }, )