diff --git a/pkg/common/clients/redis_client.go b/pkg/common/clients/redis_client.go index a6ac38e..d85d683 100644 --- a/pkg/common/clients/redis_client.go +++ b/pkg/common/clients/redis_client.go @@ -21,7 +21,8 @@ type RedisClient interface { // RedisClient represents the Redis client. type redisClientImpl struct { - client *redis.Client + client *redis.Client + cluster *redis.ClusterClient } // NewRedisClient creates a new RedisClient (redis.Client) instance. @@ -66,9 +67,8 @@ func newRedisClusterClient(config RedisConfiguration, logger logging.Logger) Red } // redis client initialization return redisClientImpl{ - client: redis.NewFailoverClient(&redis.FailoverOptions{ - MasterName: config.MasterName, - SentinelAddrs: config.Address, + cluster: redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: config.Address, TLSConfig: &tls.Config{ MinVersion: tls.VersionTLS12, RootCAs: certPool, @@ -76,7 +76,8 @@ func newRedisClusterClient(config RedisConfiguration, logger logging.Logger) Red cert, }, }, - })} + }), + } } func newRedisSingleNodeClient(config RedisConfiguration) RedisClient { @@ -86,22 +87,29 @@ func newRedisSingleNodeClient(config RedisConfiguration) RedisClient { })} } +func (c redisClientImpl) get() RedisClient { + if c.cluster != nil { + return c.cluster + } + return c.client +} + func (c redisClientImpl) Get(ctx context.Context, key string) *redis.StringCmd { - return c.client.Get(ctx, key) + return c.get().Get(ctx, key) } func (c redisClientImpl) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd { - return c.client.Set(ctx, key, value, expiration) + return c.get().Set(ctx, key, value, expiration) } func (c redisClientImpl) Publish(ctx context.Context, channel string, message interface{}) *redis.IntCmd { - return c.client.Publish(ctx, channel, message) + return c.get().Publish(ctx, channel, message) } func (c redisClientImpl) Subscribe(ctx context.Context, channels ...string) *redis.PubSub { - return c.client.Subscribe(ctx, channels...) + return c.get().Subscribe(ctx, channels...) } func (c redisClientImpl) Ping(ctx context.Context) *redis.StatusCmd { - return c.client.Ping(ctx) + return c.get().Ping(ctx) }