diff --git a/pkg/scalers/azure_eventhub_scaler.go b/pkg/scalers/azure_eventhub_scaler.go index 67ad466fa4a..a0ca61cabcb 100644 --- a/pkg/scalers/azure_eventhub_scaler.go +++ b/pkg/scalers/azure_eventhub_scaler.go @@ -274,5 +274,13 @@ func getTotalLagRelatedToPartitionAmount(unprocessedEventsCount int64, partition // Close closes Azure Event Hub Scaler func (scaler *azureEventHubScaler) Close() error { + if scaler.client != nil { + err := scaler.client.Close(nil) + if err != nil { + eventhubLog.Error(err, "error closing azure event hub client") + return err + } + } + return nil } diff --git a/pkg/scalers/redis_scaler.go b/pkg/scalers/redis_scaler.go index 26b22c31bd3..908041d0f88 100644 --- a/pkg/scalers/redis_scaler.go +++ b/pkg/scalers/redis_scaler.go @@ -23,6 +23,7 @@ const ( type redisScaler struct { metadata *redisMetadata + client *redis.Client } type redisConnectionInfo struct { @@ -48,9 +49,21 @@ func NewRedisScaler(resolvedEnv, metadata, authParams map[string]string) (Scaler if err != nil { return nil, fmt.Errorf("error parsing redis metadata: %s", err) } + options := &redis.Options{ + Addr: meta.connectionInfo.address, + Password: meta.connectionInfo.password, + DB: meta.databaseIndex, + } + + if meta.connectionInfo.enableTLS == true { + options.TLSConfig = &tls.Config{ + InsecureSkipVerify: meta.connectionInfo.enableTLS, + } + } return &redisScaler{ metadata: meta, + client: redis.NewClient(options), }, nil } @@ -93,8 +106,7 @@ func parseRedisMetadata(metadata, resolvedEnv, authParams map[string]string) (*r // IsActive checks if there is any element in the Redis list func (s *redisScaler) IsActive(ctx context.Context) (bool, error) { - length, err := getRedisListLength( - ctx, s.metadata.connectionInfo.address, s.metadata.connectionInfo.password, s.metadata.listName, s.metadata.databaseIndex, s.metadata.connectionInfo.enableTLS) + length, err := getRedisListLength(ctx, s.client, s.metadata.listName) if err != nil { redisLog.Error(err, "error") @@ -105,6 +117,14 @@ func (s *redisScaler) IsActive(ctx context.Context) (bool, error) { } func (s *redisScaler) Close() error { + if s.client != nil { + err := s.client.Close() + if err != nil { + redisLog.Error(err, "error closing redis client") + return err + } + } + return nil } @@ -128,7 +148,7 @@ func (s *redisScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { // GetMetrics connects to Redis and finds the length of the list func (s *redisScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - listLen, err := getRedisListLength(ctx, s.metadata.connectionInfo.address, s.metadata.connectionInfo.password, s.metadata.listName, s.metadata.databaseIndex, s.metadata.connectionInfo.enableTLS) + listLen, err := getRedisListLength(ctx, s.client, s.metadata.listName) if err != nil { redisLog.Error(err, "error getting list length") @@ -144,20 +164,7 @@ func (s *redisScaler) GetMetrics(ctx context.Context, metricName string, metricS return append([]external_metrics.ExternalMetricValue{}, metric), nil } -func getRedisListLength(ctx context.Context, address string, password string, listName string, dbIndex int, enableTLS bool) (int64, error) { - options := &redis.Options{ - Addr: address, - Password: password, - DB: dbIndex, - } - - if enableTLS == true { - options.TLSConfig = &tls.Config{ - InsecureSkipVerify: enableTLS, - } - } - - client := redis.NewClient(options) +func getRedisListLength(ctx context.Context, client *redis.Client, listName string) (int64, error) { var listType *redis.StatusCmd listType = client.Type(listName) diff --git a/pkg/scalers/redis_scaler_test.go b/pkg/scalers/redis_scaler_test.go index 367539ec226..7b2bc0b4dd2 100644 --- a/pkg/scalers/redis_scaler_test.go +++ b/pkg/scalers/redis_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "github.com/go-redis/redis" "testing" ) @@ -71,7 +72,7 @@ func TestRedisGetMetricSpecForScaling(t *testing.T) { if err != nil { t.Fatal("Could not parse metadata:", err) } - mockRedisScaler := redisScaler{meta} + mockRedisScaler := redisScaler{meta, &redis.Client{}} metricSpec := mockRedisScaler.GetMetricSpecForScaling() metricName := metricSpec[0].External.Metric.Name